ShapeGuard is a small tool to help with handling shapes in Pytorch and Tensorflow.

indigoviolet/shapeguard

ShapeGuard

ShapeGuard is a tool to help with handling shapes in Pytorch, Tensorflow and NumPy.

Attribution

This project was originally forked from https://github.com/Qwlouse/shapeguard, and I have made a number of interface changes/improvements/updates to it.

Convenient guarding

from shapeguard import sg
t = [[1],[2]]
sg(t, 'B,W'); sg.get()
sg(t, 'B,W, W')

This works, but importing shapeguard in every module gets tedious. You can get around this by calling sg.install() once, which makes sg() available everywhere as if it were a builtin.

(mypy still won’t be happy though, since sg() is undeclared)

sg.install()
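As a rough illustration of the idea (not ShapeGuard's actual implementation), a function can be exposed to every module by attaching it to Python's `builtins` module, which is also how `icecream.install()` is described in the roadmap. The names `guard`, `install`, and `uninstall` below are hypothetical stand-ins.

```python
import builtins


def guard(tensor, template):
    """Hypothetical stand-in for sg(); it only returns its first argument."""
    return tensor


def install(name="sg", func=guard):
    """Expose `func` under `name` in the builtins namespace."""
    setattr(builtins, name, func)


def uninstall(name="sg"):
    """Remove the name again, if it was installed."""
    if hasattr(builtins, name):
        delattr(builtins, name)


install()
# No import is needed from here on: name lookup falls back to builtins.
result = sg([1, 2], "B")  # noqa: F821 (undeclared name, as mypy would also report)
```

Because the name is injected at runtime, static checkers have no declaration to look at, which is why mypy complains.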

Forking

sg(x, "B,C,H,W")

# make a fork of the known dims by passing some params
with sg.fork(stride=2):

    # You can use the global dims here, but also infer some dims that
    # are only relevant in this fork (ie. for stride=2)
    sg(y, "B,C,h,w")

    # You can pass up some newly inferred dims that are globally
    # relevant by uppercasing them
    sg(z, "A,2")

# A is available in the global namespace now, but h,w are not
sg(xx, "A,B,C")

## Forks are reusable
with sg.fork(stride=2):
    # you can reuse the dims saved for this fork by passing the same
    # fork params
    sg(yy, "h,w,3")


## throwaway forks: these will be created anew each time (but upper case dims will still propagate to base)
with sg.fork():
    sg(yy, "h, w, 3")

## drop dims inside the fork. This can be useful to handle the last batch where size is different
with sg.fork() as batch_sg:
    sg.drop("B")
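The scoping rules above can be pictured with a `collections.ChainMap`, which the roadmap also mentions as a possible approach. This is a minimal sketch under assumptions, not the library's code: `DimStore`, `bind`, and the attribute names are hypothetical, and shapes are reduced to single name/value bindings.

```python
from collections import ChainMap
from contextlib import contextmanager


class DimStore:
    """Hypothetical store of known dims; not ShapeGuard's actual class."""

    def __init__(self):
        self.base = {}    # globally known dims
        self.forks = {}   # fork params -> dims local to that fork
        self.active = ChainMap(self.base)

    def bind(self, name, value):
        """Record a dim, or check it against the value already known."""
        known = self.active.get(name)
        if known is not None and known != value:
            raise ValueError(f"{name}: expected {known}, got {value}")
        # Upper-case names go to the base; others stay in the current layer.
        target = self.base if name.isupper() else self.active.maps[0]
        target[name] = value

    @contextmanager
    def fork(self, **params):
        key = tuple(sorted(params.items()))
        # A fork with params is stored and reused; a bare fork() is throwaway.
        local = self.forks.setdefault(key, {}) if key else {}
        previous = self.active
        self.active = ChainMap(local, self.base)
        try:
            yield self
        finally:
            self.active = previous


store = DimStore()
store.bind("B", 64)
with store.fork(stride=2):
    store.bind("h", 16)  # local to the stride=2 fork
    store.bind("A", 5)   # upper-case, so it reaches the base
with store.fork(stride=2):
    reused = store.active["h"]  # same params, same fork dims
```

Lookups inside a fork fall through to the base, while writes of lower-case names land only in the fork's own layer.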

No-op mode

with sg.noop():
    # No guards will be executed in this block
    ...
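One way such a switch could work, sketched with a hypothetical `Guard` class rather than the real `sg`: the context manager flips a flag that the check consults, and restores it on exit.

```python
from contextlib import contextmanager


class Guard:
    """Hypothetical guard that compares a shape against an expected tuple."""

    def __init__(self):
        self.enabled = True

    def __call__(self, shape, expected):
        if self.enabled and tuple(shape) != tuple(expected):
            raise ValueError(f"expected {tuple(expected)}, got {tuple(shape)}")
        return shape

    @contextmanager
    def noop(self):
        previous = self.enabled
        self.enabled = False
        try:
            yield self
        finally:
            # Restore the earlier state even if the block raised.
            self.enabled = previous


check = Guard()
with check.noop():
    skipped = check((8, 3), (64, 3))  # mismatch is ignored inside the block
```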

Shape Template Syntax

The shape template mini-DSL supports many different ways of specifying shapes:

  • numbers: 64, 32, 32, 3.0
  • named dimensions: B, width, height2, channels
  • assignment to names that can then be used in further guards: B, W2=W/2, H, C
  • wildcards: B, *, *, *
  • ellipsis: B, ..., 3
  • addition, subtraction, multiplication, division: B*N, W/2, H*(C+1)
  • comment-only dimensions: ?,_num_targets,W,C (the underscore-prefixed _num_targets won’t be stored for future guards)
  • dynamic dimensions (tensorflow): ?, H, W, C (only matches [None, H, W, C])
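To make the matching rules concrete, here is a deliberately tiny matcher covering only integers, named dimensions, `*`, `_name`, and a single `...`. It is a sketch under assumptions: the function `match` is hypothetical, it does not use the library's lark grammar, and it omits arithmetic, assignment, floats, and `?`.

```python
def match(shape, template, known):
    """Check `shape` against a comma-separated template, updating `known`."""
    tokens = [tok.strip() for tok in template.split(",")]
    shape = list(shape)
    if "..." in tokens:
        # Expand the ellipsis into as many wildcards as the rank requires.
        i = tokens.index("...")
        n_missing = len(shape) - (len(tokens) - 1)
        tokens = tokens[:i] + ["*"] * n_missing + tokens[i + 1:]
    if len(tokens) != len(shape):
        raise ValueError(f"rank mismatch: {shape} vs {template!r}")
    for size, tok in zip(shape, tokens):
        if tok == "*" or tok.startswith("_"):
            continue  # wildcard or comment-only: matches anything, stores nothing
        if tok.isdigit():
            expected = int(tok)
        else:
            # First sighting of a name records it; later ones must agree.
            expected = known.setdefault(tok, size)
        if size != expected:
            raise ValueError(f"{tok}: expected {expected}, got {size}")
    return known


dims = {}
match((64, 32, 32, 3), "B, H, W, 3", dims)
match((64, 10, 7, 3), "B, ..., 3", dims)
match((64, 5), "B, _num_targets", dims)
```

After these calls `dims` holds only `B`, `H`, and `W`; the wildcard, ellipsis, and underscore-prefixed entries leave no trace.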

Changes

  • removed generated lark parser in favor of using the library directly
  • removed dotted access to dims from ShapeGuard
  • cleaned up Python types
  • added a shim system to make it extensible
  • added support for Pytorch tensors
  • added `sg()` and `Tensor.sg` patching for Pytorch
  • added support for list templates in `sg()`
  • added support for assignment in shape templates
  • exceptions from sg() display the original line
  • added sg.fork(…) context manager
  • added sg.noop() context manager
  • added sg.install()
  • support floats in templates (to handle `2.0` etc from interpolation)

ShapeGuard() usage

This section describes how to use the underlying ShapeGuard class, which you probably don’t need to do at all.

## Basic Usage
import tensorflow as tf
from shapeguard import ShapeGuard

sg = ShapeGuard()

img = tf.ones([64, 32, 32, 3])
flat_img = tf.ones([64, 1024])
labels = tf.ones([64])

# check shape consistency
sg.guard(img, "B, H, W, C")
sg.guard(labels, "B, 1")  # raises error because of rank mismatch
sg.guard(flat_img, "B, H*W*C")  # raises error because 1024 != 32*32*3

# guard also returns the tensor, so it can be inlined
mean_img = sg.guard(tf.reduce_mean(img, axis=0), "H, W, C")

# more readable reshapes
flat_img = sg.reshape(img, 'B, H*W*C')

# evaluate templates
assert sg['H, W*C+1'] == [32, 97]

# access inferred dimensions through the dims mapping
assert sg.dims['B'] == 64

Roadmap

Use callable module pattern from snoop instead of metaclass

so that sg() returns arg

Figure out how to make this torchscript-safe

I think this will need the callable module pattern, or we will need to give up on sg.blah – currently sg is a class, so I think torchscript tries to compile the whole thing? or we have to add @torch.jit.unused everywhere to keep it exportable

See einops and use same syntax?

actually, might be nicer to wrap einops to use shapeguard as input/output around einops operations (see einops.parse_shape)

Expose sg.view()/reshape()/transpose()?

Or see einops

Display the full function call in the debug frame in sg()

a multiline call like

sg( foo )

currently only captures the first line. Use parso(?) to find the minimal number of lines that parses. See executing, https://github.com/pwwang/python-varname
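A possible shape for the "minimal number of lines that parses" idea, sketched with the standard library's `ast.parse` in place of parso. The helper `minimal_statement` is hypothetical and assumes the call starts at column zero; indented code would need dedenting first.

```python
import ast


def minimal_statement(lines, start):
    """Return the shortest run of lines from `start` that parses as Python."""
    for end in range(start + 1, len(lines) + 1):
        candidate = "\n".join(lines[start:end])
        try:
            ast.parse(candidate)
        except SyntaxError:
            continue  # still incomplete, so take one more line
        return candidate
    raise ValueError("no parsable statement found")


source = ["x = 1", "sg(", "    foo,", "    'B, C',", ")", "y = 2"]
call = minimal_statement(source, 1)
```

Here `call` spans the four lines from `sg(` through the closing parenthesis and stops before `y = 2`.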

Add a decorator @sg() that can guard function args
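One way this roadmap item might look, as a sketch only: `guard_args` is a hypothetical decorator, not an existing ShapeGuard API, and the template handling is limited to plain named dimensions. It assumes each guarded argument exposes a `.shape` attribute.

```python
import functools
import inspect
from types import SimpleNamespace


def guard_args(**templates):
    """Hypothetical decorator: check named arguments' shapes before the call."""

    def decorate(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            known = {}  # dims are shared across the arguments of one call
            for name, template in templates.items():
                shape = tuple(bound.arguments[name].shape)
                dims = [tok.strip() for tok in template.split(",")]
                if len(dims) != len(shape):
                    raise ValueError(f"{name}: rank mismatch for {template!r}")
                for dim, size in zip(dims, shape):
                    if known.setdefault(dim, size) != size:
                        raise ValueError(
                            f"{name}: {dim} expected {known[dim]}, got {size}"
                        )
            return func(*args, **kwargs)

        return wrapper

    return decorate


@guard_args(x="B, C", y="B")
def loss(x, y):
    return "ok"


# SimpleNamespace stands in for a tensor; only .shape is read.
x = SimpleNamespace(shape=(64, 10))
y = SimpleNamespace(shape=(64,))
out = loss(x, y)
```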

cache results by tensor id/template?

batch size might be smaller on the last batch.

so we actually need to fork from the base/previous batch, but allow B to be changed.

> Can now do this with sg.fork() as f and f.drop("B")

sg.fork(tmp=True)

Sometimes you want to name something and impose a constraint within the block, but it doesn’t apply beyond

> Can now do this with sg.fork() and using lower-case dims

should dynamic named dimensions be stored?

if they should, then should there be a syntax for named but unstored dimensions (for documentation purposes, to handle dimensions that will be different in the future)? e.g. _num_gt_targets already works!

use devtools.debug to produce error message containing the actual tensor name

with sg_fork(stride=)

allow a forked ShapeGuard obj which will create a singleton that can be reused later

probably need to allow this singleton to update its dims from the base singleton (maybe use chainmap?)

I think we will want this context manager to activate the forked shapeguard for all calls within it

support floats instead of int (mainly for interpolation after division or multiplication)

support any iterable in sg() “list” mode instead of only list

with sg_noop: context manager

so that pl.Trainer.tune() can run with different batch sizes etc.

support no-op mode for sg()

from shapeguard import sg_noop as sg

this isn’t sufficient because it requires changing imports all over the place

See icecream.install() to add sg to builtins

Tests

checkin_fork: another idea is if we need a mechanism for dims inferred within a fork to propagate up to the base, use uppercase Dims for base, and lowercase dims for forked

Allow externally supplied `dim=val` args to `sg()`

these should be inserted into known_dims before template processing

is this better than interpolation of the value?

it enters known_dims, which it could if we did Dim={var}

we can make sure it’s an int (sometimes floats get interpolated)

None vs -1 for dynamic dimensions

convert to common=None, via shim

What are dynamic dimensions anyway? https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

Don’t seem to be relevant to Pytorch, so nothing to do here


Languages

  • Python 78.6%
  • Jupyter Notebook 21.4%