Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MixVisionTransformer #975

Merged
merged 4 commits into from
Nov 29, 2024
Merged

Update MixVisionTransformer #975

merged 4 commits into from
Nov 29, 2024

Conversation

brianhou0208
Copy link
Contributor

Hi, @qubvel
This PR introduces support for different output strides in the Mit Encoder. The original Mit Encoder required extra parameters (H, W) to be passed along with the features during forward propagation, which is not compatible with the current self.get_stage format.

This update enables PAN, DeepLabv3, and DeepLabv3+ to support the Mit encoder.

Update

  1. Replaced the original nn.LayerNorm to support different input shapes: (B, C, H, W) or (B, N, C).
  2. Passes between different stages in the (B, C, H, W) format.

Test Code

import torch
import segmentation_models_pytorch as smp

def get_features(name='resnet18', output_stride=32):
    x = torch.rand(1, 3, 256, 256)
    backbone = smp.encoders.get_encoder(name, depth=5, output_stride=output_stride)
    features = backbone(x)
    print(name, output_stride, [f.detach().numpy().shape for f in features])

if __name__ == '__main__':
    torch.manual_seed(0)
    get_features('resnet18', 32)
    get_features('resnet18', 16)
    get_features('resnet18', 8)

    get_features('mit_b0', 32)
    get_features('mit_b0', 16)
    get_features('mit_b0', 8)

    get_features('tu-mobilenetv3_small_050', 32)
    get_features('tu-mobilenetv3_small_050', 16)
    get_features('tu-mobilenetv3_small_050', 8)

output

resnet18 32 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8)]
resnet18 16 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 16, 16)]
resnet18 8 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 32, 32), (1, 512, 32, 32)]
mit_b0 32 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 16, 16), (1, 256, 8, 8)]
mit_b0 16 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 16, 16), (1, 256, 16, 16)]
mit_b0 8 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 32, 32), (1, 256, 32, 32)]
tu-mobilenetv3_small_050 32 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 16, 16), (1, 288, 8, 8)]
tu-mobilenetv3_small_050 16 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 16, 16), (1, 288, 16, 16)]
tu-mobilenetv3_small_050 8 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 32, 32), (1, 288, 32, 32)]

Copy link
Collaborator

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brianhou0208 Thanks for one more feature! I really appreciate your contributions! See the comments below

segmentation_models_pytorch/encoders/mix_transformer.py Outdated Show resolved Hide resolved
segmentation_models_pytorch/encoders/mix_transformer.py Outdated Show resolved Hide resolved
segmentation_models_pytorch/encoders/mix_transformer.py Outdated Show resolved Hide resolved
@qubvel qubvel changed the title Updata MixVisionTransformer Update MixVisionTransformer Nov 11, 2024
@qubvel qubvel self-requested a review November 29, 2024 17:32
@qubvel
Copy link
Collaborator

qubvel commented Nov 29, 2024

Hey, @brianhou0208! Can you please rebase/merge the main?

@qubvel
Copy link
Collaborator

qubvel commented Nov 29, 2024

Waiting for the tests and merging! Thanks for the update and iterations!

Copy link
Collaborator

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging!

@qubvel qubvel merged commit 441a602 into qubvel-org:main Nov 29, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants