Hi,
I have trained a CNN with Keras in order to segment specific patterns. It works really well, but now I have to start the "production phase", so my CNN must segment thousands of images.
So, for a given image that the CNN has to segment, I have to cut a patch around each pixel of the image and feed all the patches to the CNN. So far I have used this solution:
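import numpy as np
import imageio.v2 as imageio  # scipy.ndimage.imread was removed in SciPy 1.2
WindowSize = 23  # patch size (odd, so each patch has a centre pixel)
ws2 = WindowSize // 2  # half-width; integer division so it can be used as an index
image_data = imageio.imread(image_file).astype(float)
SizeY, SizeX = image_data.shape  # grayscale image assumed
imtest = np.ndarray(shape=(SizeX - 2*ws2, 1, WindowSize, WindowSize), dtype=np.float32)
for y in range(ws2, SizeY - ws2):
    # fill imtest with the patches of row y, one per pixel x
    for x in range(ws2, SizeX - ws2):
        imtest[x - ws2, 0] = image_data[y - ws2:y + ws2 + 1, x - ws2:x + ws2 + 1]
    # imtest now holds every patch of row y (e.g. ready for model.predict)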
I work row by row so that I don't lose the patch coordinates. But this solution is really slow. Is there a faster way to do it?
I've also heard about generators in Keras, but they seem to be useful for training with the fit_generator function, not for predicting on an image and then segmenting it, because they do not keep the patch coordinates.
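For illustration only: a plain Python generator (not the Keras fit_generator machinery) can yield the row index together with each row's batch of patches, so the coordinates travel with the data. This is a minimal sketch; `row_patches`, `segment` and `predict_fn` are hypothetical names, and with Keras `predict_fn` might be something like `lambda p: model.predict(p).ravel()`, assuming one output value per patch and a channels-first model matching the `(N, 1, WindowSize, WindowSize)` shape above. It is about keeping coordinates, not about speed.

```python
import numpy as np

def row_patches(image, window_size):
    """Yield (y, patches) for every row y whose patches fit inside `image`.

    patches has shape (n_cols, 1, window_size, window_size): one patch per
    pixel x of row y, left to right, so patch i is centred on (y, i + half).
    """
    half = window_size // 2
    size_y, size_x = image.shape
    for y in range(half, size_y - half):
        band = image[y - half:y + half + 1]  # the window_size rows around y
        patches = np.stack([band[:, x - half:x + half + 1]
                            for x in range(half, size_x - half)])
        yield y, patches[:, None].astype(np.float32)  # add a channel axis

def segment(image, window_size, predict_fn):
    """Fill an output map row by row from per-patch predictions."""
    half = window_size // 2
    out = np.zeros(image.shape, dtype=np.float32)
    for y, patches in row_patches(image, window_size):
        # predict_fn (hypothetical) must return one score per patch, in order
        out[y, half:image.shape[1] - half] = predict_fn(patches)
    return out

img = np.random.rand(40, 50)
# stand-in "model": return the centre pixel of each patch
seg = segment(img, 23, lambda p: p[:, 0, 11, 11])
```

With the centre-pixel stand-in, the valid region of `seg` reproduces the input image, which confirms that each prediction lands at the pixel its patch was cut around.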
Any idea?
You could take a look at extract_patches_2d in scikit-learn. It should be faster!
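A minimal sketch of `sklearn.feature_extraction.image.extract_patches_2d` on a toy image (the image and the index `k` are made up for illustration). Patches come back in row-major order of their top-left corners, so a flat index can be mapped back to a pixel position; the patch centre is at `(r + S // 2, c + S // 2)`.

```python
import numpy as np
from sklearn.feature_extraction.image import extract_patches_2d

image = np.arange(8 * 10, dtype=np.float32).reshape(8, 10)  # toy 8x10 image
S = 3  # patch size
patches = extract_patches_2d(image, (S, S))
# one patch per valid top-left corner: (8 - 3 + 1) * (10 - 3 + 1) = 48
print(patches.shape)  # (48, 3, 3)

# row-major order: flat index k -> top-left corner (r, c)
n_cols = image.shape[1] - S + 1
k = 13
r, c = divmod(k, n_cols)
assert np.array_equal(patches[k], image[r:r + S, c:c + S])
```

Note that the result is a regular array holding a copy of every patch, so its size grows with the number of pixels times S*S.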
Hi, thank you for your answer.
I just ran the test, and this solution is as fast as mine, but it blows up the memory usage!
If the image I want to segment has dimensions WxH and I want to cut patches of dimensions SxS, then the memory used is (W-S+1)x(H-S+1)xSxS values.
I will have to segment images of 6500x4500 pixels, with 31x31 patches.
And that's just the beginning, next it will be volumes.
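To put the formula above into numbers, here is a small sketch (the helper name `patch_memory_bytes` is hypothetical, and float32 storage, 4 bytes per value, is assumed) for the 6500x4500 image with 31x31 patches:

```python
def patch_memory_bytes(width, height, patch, bytes_per_value=4):
    """Bytes needed to store every patch x patch window of a width x height image."""
    n_patches = (width - patch + 1) * (height - patch + 1)
    return n_patches * patch * patch * bytes_per_value

total = patch_memory_bytes(6500, 4500, 31)
print(total)                   # 111171939600
print(round(total / 2**30, 1))  # 103.5 (GiB)
```

So materialising all patches at once needs roughly 103.5 GiB in float32, before the CNN even runs.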
@FiReTiTi you might want to take a look at numpy's stride_tricks.
Here you can see them in use: GameOfLifeStrides.
For your specific problem:
import numpy as np
from numpy.lib.stride_tricks import as_strided
img = np.arange(6500 * 4500, dtype=np.float32).reshape(6500, 4500)
# view of every 31x31 window; no data is copied
img_strided = as_strided(img, shape=(6500 - 30, 4500 - 30, 31, 31),
                         strides=img.strides + img.strides, writeable=False)
# img_strided.shape == (6470, 4470, 31, 31)
# iterate row by row: reshape(-1, 31, 31) on this view would copy every patch
for row_patches in img_strided:
    # row_patches.shape == (4470, 31, 31), one patch per column
    for img_patch in row_patches:
        # img_patch.shape == (patch_height, patch_width)
        ...
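As an alternative sketch, NumPy 1.20+ provides `numpy.lib.stride_tricks.sliding_window_view`, which builds the same `(H-S+1, W-S+1, S, S)` view as the `as_strided` call but computes the shape and strides itself. The example below uses a small random stand-in image, shows that flattening the 4-D view with `reshape` allocates a new array (a copy of every patch), and processes one row of patches at a time while keeping coordinates. The centre-pixel `scores` line is a placeholder for a real model call.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

S = 31               # patch size, as in the question
half = S // 2
img = np.random.rand(60, 50).astype(np.float32)  # small stand-in image

# Read-only view with shape (H - S + 1, W - S + 1, S, S); no data is copied.
windows = sliding_window_view(img, (S, S))
print(windows.shape)                      # (30, 20, 31, 31)
print(np.may_share_memory(windows, img))  # True

# Merging the first two axes cannot be expressed with strides,
# so reshape allocates a new array holding every patch.
flat = windows.reshape(-1, S, S)
print(np.may_share_memory(flat, img))     # False

# Row-by-row processing: only one row of patches is copied at a time,
# and windows[r, c] is centred on pixel (r + half, c + half).
out = np.zeros_like(img)
for r, row in enumerate(windows):
    batch = np.ascontiguousarray(row)     # (W - S + 1, S, S)
    scores = batch[:, half, half]         # stand-in for model.predict(batch)
    out[r + half, half:half + len(scores)] = scores
```

For the full 6500x4500 float32 image, one row of 4470 patches of 31x31 is about 17 MB, so each batch stays small while the full view itself costs no extra memory.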
Thanks