Skip to content

Commit

Permalink
Merge pull request #18 from nekitmm/tf2
Browse files Browse the repository at this point in the history
Support for 16bit and 8bit TIFF files!
  • Loading branch information
nekitmm authored Sep 13, 2020
2 parents 7c7e12e + 735207f commit d9aa3ff
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
16 changes: 10 additions & 6 deletions starnet_v1_TF2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,22 @@
"starnet.plot_history(last = 100000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h4>Image transformation example. Works with 8bit and 16bit TIFF files</h4>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"name = \"rgb_test5.tif\"\n",
"output = starnet.transform(name)\n",
"if starnet.mode == 'RGB':\n",
" img.fromarray(output.astype('uint8'), mode = 'RGB').save(name + '_starless.tif')\n",
"else:\n",
" img.fromarray(output.astype('uint8'), mode = 'L').save(name + '_starless.tif')"
"in_name = \"rgb_test5_16.tif\"\n",
"out_name = \"rgb_test5_16_starless.tif\"\n",
"starnet.transform(in_name, out_name)"
]
},
{
Expand Down
25 changes: 18 additions & 7 deletions starnet_v1_TF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tensorflow.keras.layers as L
import copy
import pickle
import tifffile as tiff

from matplotlib import pyplot as plt

Expand Down Expand Up @@ -303,8 +304,15 @@ def save_model(self, weights_filename, history_filename = None):
with open(history_filename + '_' + self.mode + '.pkl', 'wb') as f:
pickle.dump(self.history, f)

def transform(self, name):
image = np.array(img.open(name), dtype = np.float32)
def transform(self, in_name, out_name):
data = tiff.imread(in_name)
input_dtype = data.dtype
if input_dtype == 'uint16':
image = (data / 255.0 / 255.0).astype('float32')
elif input_dtype == 'uint8':
image = (data / 255.0).astype('float32')
else:
raise ValueError('Unknown image dtype:', data.dtype)

if self.mode == 'Greyscale' and len(image.shape) == 3:
raise ValueError('You loaded Greyscale model, but the image is RGB!')
Expand Down Expand Up @@ -334,8 +342,6 @@ def transform(self, name):
image = np.concatenate((image, image[:, (w - offset) :, :]), axis = 1)
image = np.concatenate((image[:, : offset, :], image), axis = 1)

image /= 255

image = image * 2 - 1

output = copy.deepcopy(image)
Expand All @@ -353,10 +359,15 @@ def transform(self, name):
output = np.clip(output, 0, 1)

if self.mode == 'Greyscale':
return output[offset:-(offset+dh), offset:-(offset+dw), 0] * 255
output = output[offset:-(offset+dh), offset:-(offset+dw), 0]
else:
return output[offset:-(offset+dh), offset:-(offset+dw), :] * 255

output = output[offset:-(offset+dh), offset:-(offset+dw), :]

if input_dtype == 'uint8':
tiff.imsave(out_name, (output * 255).astype('uint8'))
else:
tiff.imsave(out_name, (output * 255 * 255).astype('uint16'))

def _generator(self, m):
layers = []

Expand Down
25 changes: 13 additions & 12 deletions starnet_v1_TF2_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -12,6 +12,8 @@
"tf.get_logger().setLevel(logging.ERROR)\n",
"from starnet_v1_TF2 import StarNet\n",
"\n",
"import tifffile as tiff\n",
"\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
Expand All @@ -20,38 +22,37 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<p>Load a model. There are two modes - RGB and Greyscale.</p>"
"<h3>Load a model. There are two modes - RGB and Greyscale.</h3>"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"starnet = StarNet(mode = 'RGB', window_size = 512, stride = 128)\n",
"starnet = StarNet(mode = 'Greyscale', window_size = 512, stride = 128)\n",
"starnet.load_model('./weights', './history')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<p>Load and transform an image.</p>"
"<h3>Load and transform an image.</h3>\n",
"\n",
"<b>Now can work with either 8bit or 16bit TIFF files!</b>"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"name = \"rgb_test5.tif\"\n",
"output = starnet.transform(name)\n",
"if starnet.mode == 'RGB':\n",
" img.fromarray(output.astype('uint8'), mode = 'RGB').save(name + '_starless.tif')\n",
"else:\n",
" img.fromarray(output.astype('uint8'), mode = 'L').save(name + '_starless.tif')"
"in_name = \"rgb_test5_L16.tif\"\n",
"out_name = \"rgb_test5_L16_starless.tif\"\n",
"starnet.transform(in_name, out_name)"
]
},
{
Expand Down

0 comments on commit d9aa3ff

Please sign in to comment.