Skip to content

Commit

Permalink
Merge pull request #14 from nekitmm/tf2
Browse files Browse the repository at this point in the history
Tf2
  • Loading branch information
nekitmm authored Aug 17, 2020
2 parents 222e1b4 + 9c89a52 commit 9e8afbd
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 36 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
<h1>Update</h1>

Pushed a new implementation os starnet in TF2.x. The whole implementation is in one file starnet_v1_TF2.py.

I also created a few Jupyter notebooks for ease of use:

1. starnet_v1_TF2_transform.ipynb - loads and transforms an image.
2. starnet_v1_TF2.ipynb - more detailed example that loads a model and shows how to train it (really simple as well I think).


**StarNet** is a neural network that can remove stars from images in one simple step leaving only background.

More technically it is a convolutional residual net with encoder-decoder architecture and with L1, Adversarial and Perceptual losses.
Expand Down
32 changes: 0 additions & 32 deletions starnet_v1_TF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,6 @@ def _ramp(self, x):
return tf.clip_by_value(x, 0, 1)

def _augmentator(self, o, s):
# rotate
if np.random.rand() < 0.33:
r = np.random.randint(360)
pad = int(self.window_size / 2)
bc = np.random.rand()

if np.random.rand() < 0.50:
resample = img.BICUBIC
elif np.random.rand() < 0.75:
resample = img.BILINEAR
else:
resample = img.NEAREST

Xtmp = copy.copy(o)
Ytmp = copy.copy(s)

if np.random.rand() < 0.90:
Xtmp = np.pad(Xtmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'reflect')
Ytmp = np.pad(Ytmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'reflect')
else:
Xtmp = np.pad(Xtmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'constant', constant_values = bc)
Ytmp = np.pad(Ytmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'constant', constant_values = bc)

Xtmp = img.fromarray(np.uint8(Xtmp * 255))
Ytmp = img.fromarray(np.uint8(Ytmp * 255))

Xtmp = Xtmp.rotate(r, resample = resample)
Ytmp = Ytmp.rotate(r, resample = resample)

o = np.array(Xtmp)[pad:-pad, pad:-pad] / 255
s = np.array(Ytmp)[pad:-pad, pad:-pad] / 255

# flip horizontally
if np.random.rand() < 0.50:
o = np.flip(o, axis = 1)
Expand Down
69 changes: 69 additions & 0 deletions starnet_v1_TF2_transform.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from PIL import Image as img\n",
"import logging\n",
"tf.get_logger().setLevel(logging.ERROR)\n",
"from starnet_v1_TF2 import StarNet"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"starnet = StarNet(mode = 'RGB', window_size = 512)\n",
"starnet.load_model('./weights', './history')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"name = \"rgb_test5.tif\"\n",
"output = starnet.transform(name)\n",
"if starnet.mode == 'RGB':\n",
" img.fromarray(output.astype('uint8'), mode = 'RGB').save(name + '_starless.tif')\n",
"else:\n",
" img.fromarray(output.astype('uint8'), mode = 'L').save(name + '_starless.tif')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
10 changes: 6 additions & 4 deletions wherearemyweights.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
To download pre-trained weights, follow this link (~500Mb)
To download pre-trained weights, follow the links below:

https://www.dropbox.com/s/atcs42ox4n99w96/starnet_weights.zip?dl=0
For TF1.x implementation:

And extract all the content into root folder (folder with this file).
https://www.dropbox.com/s/atcs42ox4n99w96/starnet_weights.zip?dl=0

For TF2.x implementation:

https://www.dropbox.com/s/lcgn5gvnxpo27s5/starnet_weights2.zip?dl=0

Or you might also use lfs (if you know what that means, otherwise ignore this).
And extract all the content into root folder (folder with this file).

0 comments on commit 9e8afbd

Please sign in to comment.