Skip to content

Commit

Permalink
update tf1.15 code to play nice with 16 bit tiff
Browse files Browse the repository at this point in the history
  • Loading branch information
NevesLucas committed Sep 16, 2021
1 parent 7079110 commit 73817a4
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 40 deletions.
6 changes: 3 additions & 3 deletions starnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
learning_rates = [0.000002, 0.0000005] # Learning rates: the first is for generator, the second is for discriminator. Usually they are the same,
# but who knows. In the beginning of training suitable values are about 0.0002 and then can be made smaller
# as the model gets better.
stride = 128 # Stride value for image transformation. The smaller it gets, the less artefacts you get in the final image,
stride = 64 # Stride value for image transformation. The smaller it gets, the less artefacts you get in the final image,
# but the more time it takes to transform an image. 100 looks about optimal for now.

if len(sys.argv) > 1:
Expand Down Expand Up @@ -103,8 +103,8 @@
else:
start = time.time()
import transform
transform.transform(image = sys.argv[2],
stride = stride)
transform.transform(imageName = sys.argv[2],
stride = stride)
stop = time.time()
t = float((stop - start) / 60)
if t > 60.0:
Expand Down
146 changes: 109 additions & 37 deletions transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,86 @@
import tensorflow as tf
from PIL import Image as img
import matplotlib.pyplot as plt
from scipy.misc import toimage
import matplotlib
import sys
import time
import model
import starnet_utils
import tifffile as tiff

WINDOW_SIZE = 256 # Size of the image fed to net. Do not change until you know what you are doing! Default is 256
# and changing this will force you to train the net anew.
def transform(self, in_name, out_name):
data = tiff.imread(in_name)
if len(data.shape) > 3:
layer = input("Tiff has %d layers, please enter layer to process: " % data.shape[0])
layer = int(layer)
data = data[layer]

input_dtype = data.dtype
if input_dtype == 'uint16':
image = (data / 255.0 / 255.0).astype('float32')
elif input_dtype == 'uint8':
image = (data / 255.0).astype('float32')
else:
raise ValueError('Unknown image dtype:', data.dtype)

if self.mode == 'Greyscale' and len(image.shape) == 3:
raise ValueError('You loaded Greyscale model, but the image is RGB!')

if self.mode == 'Greyscale':
image = image[:, :, None]

if self.mode == 'RGB' and len(image.shape) == 2:
raise ValueError('You loaded RGB model, but the image is Greyscale!')

if self.mode == 'RGB' and image.shape[2] == 4:
print("Input image has 4 channels. Removing Alpha-Channel")
image = image[:, :, [0, 1, 2]]

offset = int((self.window_size - self.stride) / 2)

h, w, _ = image.shape

ith = int(h / self.stride) + 1
itw = int(w / self.stride) + 1

dh = ith * self.stride - h
dw = itw * self.stride - w

image = np.concatenate((image, image[(h - dh):, :, :]), axis=0)
image = np.concatenate((image, image[:, (w - dw):, :]), axis=1)

h, w, _ = image.shape
image = np.concatenate((image, image[(h - offset):, :, :]), axis=0)
image = np.concatenate((image[: offset, :, :], image), axis=0)
image = np.concatenate((image, image[:, (w - offset):, :]), axis=1)
image = np.concatenate((image[:, : offset, :], image), axis=1)

image = image * 2 - 1

output = copy.deepcopy(image)

for i in range(ith):
for j in range(itw):
x = self.stride * i
y = self.stride * j

tile = np.expand_dims(image[x:x + self.window_size, y:y + self.window_size, :], axis=0)
tile = (self.G(tile)[0] + 1) / 2
tile = tile[offset:offset + self.stride, offset:offset + self.stride, :]
output[x + offset:self.stride * (i + 1) + offset, y + offset:self.stride * (j + 1) + offset, :] = tile



if self.mode == 'Greyscale':
output = output[offset:-(offset + dh), offset:-(offset + dw), 0]
else:
output = output[offset:-(offset + dh), offset:-(offset + dw), :]


def transform(imageName, stride):

def transform(image, stride):

# placeholders for tensorflow
X = tf.placeholder(tf.float32, shape = [None, WINDOW_SIZE, WINDOW_SIZE, 3], name = "X")
Y = tf.placeholder(tf.float32, shape = [None, WINDOW_SIZE, WINDOW_SIZE, 3], name = "Y")
Expand All @@ -52,14 +120,23 @@ def transform(image, stride):

# read input image
print("Opening input image...")
input = np.array(img.open(image), dtype = np.float32)
print("Done!")

# rescale to [-1, 1]
input /= 255
# backup to use for mask
backup = np.copy(input)
input = input * 2 - 1
data = tiff.imread(imageName)
if len(data.shape) > 3:
layer = input("Tiff has %d layers, please enter layer to process: " % data.shape[0])
layer = int(layer)
data = data[layer]

input_dtype = data.dtype
if input_dtype == 'uint16':
image = (data / 255.0 / 255.0).astype('float32')
elif input_dtype == 'uint8':
image = (data / 255.0).astype('float32')
else:
raise ValueError('Unknown image dtype:', data.dtype)

if image.shape[2] == 4:
print("Input image has 4 channels. Removing Alpha-Channel")
image = image[:, :, [0, 1, 2]]


# now some tricky magic
Expand All @@ -69,7 +146,7 @@ def transform(image, stride):

# get size of the image and calculate numbers of iterations needed to transform it
# given stride and taking into account that we will pad it a bit later (+1 comes from that)
h, w, _ = input.shape
h, w, _ = image.shape
ith = int(h / stride) + 1
itw = int(w / stride) + 1

Expand All @@ -78,18 +155,18 @@ def transform(image, stride):
dw = itw * stride - w

# pad image using parts of the image itself and values calculated above
input = np.concatenate((input, input[(h - dh) :, :, :]), axis = 0)
input = np.concatenate((input, input[:, (w - dw) :, :]), axis = 1)
image = np.concatenate((image, image[(h - dh) :, :, :]), axis = 0)
image = np.concatenate((image, image[:, (w - dw) :, :]), axis = 1)

# get image size again and pad to allow offsets on all four sides of the image
h, w, _ = input.shape
input = np.concatenate((input, input[(h - offset) :, :, :]), axis = 0)
input = np.concatenate((input[: offset, :, :], input), axis = 0)
input = np.concatenate((input, input[:, (w - offset) :, :]), axis = 1)
input = np.concatenate((input[:, : offset, :], input), axis = 1)
h, w, _ = image.shape
image = np.concatenate((image, image[(h - offset) :, :, :]), axis = 0)
image = np.concatenate((image[: offset, :, :], image), axis = 0)
image = np.concatenate((image, image[:, (w - offset) :, :]), axis = 1)
image = np.concatenate((image[:, : offset, :], image), axis = 1)

# copy input image to output
output = np.copy(input)
output = np.copy(image)

# helper array just to add fourth dimension to net input
tmp = np.zeros((1, WINDOW_SIZE, WINDOW_SIZE, 3), dtype = np.float)
Expand All @@ -103,30 +180,25 @@ def transform(image, stride):
y = stride * j

# write piece of input image to tmp array
tmp[0] = input[x : x + WINDOW_SIZE, y : y + WINDOW_SIZE, :]
tmp[0] = image[x : x + WINDOW_SIZE, y : y + WINDOW_SIZE, :]

# transform
result = sess.run(outputs, feed_dict = {X:tmp})

# write transformed array to output
output[x + offset : x + stride + offset, y + offset: y + stride + offset, :] = result[0, offset : stride + offset, offset : stride + offset, :]
print("Transforming input image... Done!")

# rescale back to [0, 1]
output = (output + 1) / 2
output = np.clip(output, 0, 1)

# leave only necessary part, without pads added earlier
output = output[offset : - (offset + dh), offset : - (offset + dw), :]
output = output[offset:-(offset + dh), offset:-(offset + dw), :]

print("Saving output image...")
toimage(output * 255, cmin = 0, cmax = 255).save('./' + image + '_starless.tif')
print("Done!")

print("Saving mask...")
# mask showing areas that were changed significantly
mask = (((backup * 255).astype(np.int_) - (output * 255).astype(np.int_)) > 25).astype(np.int_)
mask = mask.max(axis = 2, keepdims = True)
mask = np.concatenate((mask, mask, mask), axis = 2)
toimage(mask * 255, cmin = 0, cmax = 255).save('./' + image + '_mask.tif')
print("Done!")

if input_dtype == 'uint8':
tiff.imsave('./' + imageName + '_starless.tif', (output * 255).astype('uint8'))
else:
tiff.imsave('./' + imageName + '_starless.tif', (output * 255 * 255).astype('uint16'))

print("Done!")

0 comments on commit 73817a4

Please sign in to comment.