transfer learning for resnet50-res512-all #92
It should be as simple as using this line:
model = xrv.models.ResNet(weights="resnet50-res512-all")
in this script: https://github.com/mlmed/torchxrayvision/blob/master/scripts/transfer_learning.ipynb
Also change the resizing to xrv.datasets.XRayResizer(512) so that the images are 512x512.
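XRayResizer(512) makes every image 512x512 before it reaches the model. As a rough illustration of what that preprocessing amounts to, here is a minimal stand-in in plain PyTorch. It is not torchxrayvision's implementation: the function name `center_crop_and_resize` is hypothetical, it assumes a single-channel image tensor of shape (1, H, W), and it assumes a center crop is applied before resizing.

```python
import torch
import torch.nn.functional as F


def center_crop_and_resize(img: torch.Tensor, size: int = 512) -> torch.Tensor:
    """Crop a (1, H, W) image to a centered square, then resize it to (1, size, size).

    Hypothetical stand-in for a center-crop + XRayResizer(size) pipeline,
    not the torchxrayvision implementation.
    """
    _, h, w = img.shape
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    square = img[:, top:top + side, left:left + side]
    # F.interpolate expects a batch dimension: (N, C, H, W)
    resized = F.interpolate(square.unsqueeze(0), size=(size, size),
                            mode="bilinear", align_corners=False)
    return resized.squeeze(0)


# Example: a fake 600x800 single-channel image
img = torch.rand(1, 600, 800)
out = center_crop_and_resize(img, 512)
print(out.shape)  # torch.Size([1, 512, 512])
```

The same function with `size=224` corresponds to the DenseNet setting; only the target resolution changes between the two models.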
Hi Joseph,
Thanks for the response! I used densenet121 (224x224) for transfer learning and it worked great. I then tried ResNet50 and modified the fc layers for binary classification as below:
import torch.nn as nn
import torchxrayvision as xrv
model = xrv.models.ResNet(weights="resnet50-res512-all")
model.fc = nn.Sequential(
    nn.Linear(2048, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 1))
But the script gets stuck at the training step:
outputs = model(inputs)
The output shape is (32, 18). I know 32 is the batch size, but I don't know where 18 comes from. Shouldn't it be 1 instead?
It seems to me the setup for ResNet is quite different from DenseNet. I am quite new to this and hope to get the ResNet working. Thank you for helping out!
Yue
On Mon, Apr 4, 2022 at 4:22 PM Joseph Paul Cohen wrote:
It should be as simple as using this line:
model = xrv.models.ResNet(weights="resnet50-res512-all")
in this script:
https://github.com/mlmed/torchxrayvision/blob/master/scripts/transfer_learning.ipynb
Also change the resizing to xrv.datasets.XRayResizer(512) so the images
are 512x512
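Where the 18 comes from can be illustrated with a toy model. The assumption, based on the maintainer's explanation in this thread, is that `xrv.models.ResNet` keeps the actual network in an inner `.model` attribute and calls it in `forward`. Assigning `model.fc` on the wrapper then just registers a new submodule that `forward` never uses, so the original 18-output classifier stays in place. The classes `InnerNet` and `Wrapper` below are hypothetical and far smaller than ResNet50; the sizes 16 and 8 are placeholders, and 18 mirrors the output width reported in the question.

```python
import torch
import torch.nn as nn


class InnerNet(nn.Module):
    """Hypothetical stand-in for a backbone whose classifier lives at .fc."""

    def __init__(self, num_features: int = 8, num_outputs: int = 18):
        super().__init__()
        self.features = nn.Linear(16, num_features)
        self.fc = nn.Linear(num_features, num_outputs)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


class Wrapper(nn.Module):
    """Hypothetical wrapper that holds the real network in .model."""

    def __init__(self):
        super().__init__()
        self.model = InnerNet()

    def forward(self, x):
        return self.model(x)  # wrapper.fc is never consulted here


x = torch.randn(32, 16)
wrapper = Wrapper()

wrapper.fc = nn.Linear(8, 1)        # registers a new, unused submodule
print(wrapper(x).shape)             # torch.Size([32, 18])

wrapper.model.fc = nn.Linear(8, 1)  # replaces the classifier forward actually uses
print(wrapper(x).shape)             # torch.Size([32, 1])
```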
Oh, sorry, I responded too fast and didn't test the code. The ResNet wrapper loads an internal ResNet model, so the fc layer is located at model.model.fc:
import torch
import torchxrayvision as xrv
model = xrv.models.ResNet(weights="resnet50-res512-all")
model.op_threshs = None  # prevent pre-trained model calibration
model.model.fc = torch.nn.Linear(2048, 1)  # reinitialize classifier
optimizer = torch.optim.Adam(model.model.fc.parameters())  # only train classifier
criterion = torch.nn.BCEWithLogitsLoss()
I tested the above code and it seems to train correctly.
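One training step with the setup described above (a single-logit classifier, only its parameters given to Adam, BCEWithLogitsLoss) can be sketched as follows. To keep it self-contained, a small hypothetical `nn.Sequential` backbone stands in for the pre-trained ResNet50 and a random batch stands in for real data. The loss is computed on raw logits because `BCEWithLogitsLoss` applies the sigmoid itself, which is one reason to set `op_threshs = None` so the outputs are not calibrated first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny "backbone" and a 1-logit classifier head.
backbone = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
fc = nn.Linear(8, 1)  # plays the role of model.model.fc
model = nn.Sequential(backbone, fc)

optimizer = torch.optim.Adam(fc.parameters())   # only the classifier is optimized
criterion = nn.BCEWithLogitsLoss()              # expects raw logits

inputs = torch.randn(32, 16)                    # fake batch of 32 samples
targets = torch.randint(0, 2, (32, 1)).float()  # binary labels, same shape as outputs

backbone_before = [p.detach().clone() for p in backbone.parameters()]
fc_before = fc.weight.detach().clone()

model.train()
optimizer.zero_grad()
outputs = model(inputs)                         # shape (32, 1): one logit per sample
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Backbone weights receive gradients but are never stepped by the optimizer.
backbone_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(backbone_before, backbone.parameters())
)
fc_changed = not torch.equal(fc_before, fc.weight)
print(outputs.shape, backbone_unchanged, fc_changed)
```

With a real ResNet in train() mode, BatchNorm layers still update their running statistics even when only fc is optimized. Calling requires_grad_(False) on the backbone parameters also skips computing their gradients, which saves memory and time.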
Great library! Would you provide the transfer learning code for resnet50-res512-all as well? Thank you so much!