It would be nice to have a batched image resize (interpolation) operation implemented in JAX. AFAIK, the equivalents are `scipy.ndimage.zoom` in SciPy and `tensorflow.image.resize_images` in TensorFlow.
Good idea! Is scipy.ndimage.zoom the API you want?
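For reference, here is roughly what a batched call looks like with `scipy.ndimage.zoom`: per-axis zoom factors of 1 leave the batch and channel axes untouched. This is a minimal sketch assuming NHWC arrays; `order=1` selects linear (spline order 1) interpolation.

```python
import numpy as np
from scipy import ndimage

# A batch of two 4x4 RGB images (NHWC), each filled with a constant value.
images = np.stack([np.full((4, 4, 3), 1.0), np.full((4, 4, 3), 2.0)])

# Zoom factor 1 on the batch and channel axes, 2 on height and width.
resized = ndimage.zoom(images, zoom=(1, 2, 2, 1), order=1)
print(resized.shape)  # (2, 8, 8, 3)
```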
In the meantime, here's a pure-NumPy snippet you can use:

```python
import numpy as np

def interpolate_bilinear(im, rows, cols):
    # based on http://stackoverflow.com/a/12729229
    # im: (..., nrows, ncols, channels); rows and cols are same-shaped arrays
    # of fractional pixel coordinates (integer k = center of pixel k).
    col_lo = np.floor(cols).astype(int)
    col_hi = col_lo + 1
    row_lo = np.floor(rows).astype(int)
    row_hi = row_lo + 1

    # Clamp neighbor indices so out-of-range samples reuse the edge pixels.
    nrows, ncols = im.shape[-3:-1]
    def cclip(cols): return np.clip(cols, 0, ncols - 1)
    def rclip(rows): return np.clip(rows, 0, nrows - 1)

    # The four neighboring pixels of every sample point.
    Ia = im[..., rclip(row_lo), cclip(col_lo), :]
    Ib = im[..., rclip(row_hi), cclip(col_lo), :]
    Ic = im[..., rclip(row_lo), cclip(col_hi), :]
    Id = im[..., rclip(row_hi), cclip(col_hi), :]

    # Bilinear weights; the trailing axis broadcasts them over channels.
    wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
    wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
    wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
    wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)
    return wa*Ia + wb*Ib + wc*Ic + wd*Id
```
Ah, great, thanks Matt, this is very useful!
Regarding the API, I don't have a preference. I've used TF before, but following the SciPy API seems most consistent with JAX's approach.
FYI, I just found out that `scipy.ndimage.zoom` has a lot of open issues (https://github.com/scipy/scipy/issues?utf8=%E2%9C%93&q=is%3Aopen%20is%3Aissue%20label%3Ascipy.ndimage%20zoom), so implementing it correctly may be tricky.
I spent a bit of time looking at this. My guess is that `skimage.transform.resize` (or `rescale`) might be a good choice. It looks to me like `scipy.ndimage.zoom` suffers from some of the same pixel-centering bugs that TF's bilinear resize used to suffer from (https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35).
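To make the pixel-centering issue concrete, here is a small sketch of how output pixel indices map back to input coordinates under three conventions: half-pixel centers (used, e.g., by TF 2.x's `tf.image.resize`), TF 1.x's legacy default (`align_corners=False` without half-pixel centers), and `align_corners=True`. The helper `source_coords` and its mode names are hypothetical; coordinates use the convention that integer `k` is the center of input pixel `k`.

```python
import jax.numpy as jnp

def source_coords(in_size, out_size, mode):
    """Map output pixel indices to fractional input coordinates (hypothetical helper)."""
    j = jnp.arange(out_size, dtype=jnp.float32)
    scale = in_size / out_size
    if mode == "half_pixel":
        # Align pixel areas: output center (j + 0.5) in edge units, minus 0.5.
        return (j + 0.5) * scale - 0.5
    if mode == "legacy_asymmetric":
        # No half-pixel offset, so the sampling grid is shifted.
        return j * scale
    if mode == "align_corners":
        # First and last pixel centers coincide exactly.
        return j * (in_size - 1) / (out_size - 1)
    raise ValueError(f"unknown mode: {mode}")

for mode in ("half_pixel", "legacy_asymmetric", "align_corners"):
    print(mode, source_coords(4, 8, mode))
```

For a 4 to 8 upsample, the half-pixel grid is symmetric about the image center (from -0.25 to 3.25), while the legacy grid runs from 0 to 3.5 and is biased toward the bottom/right.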
In case it's handy for anyone, I wrote some pure NumPy code to do aligned "averaging over pixels", which could probably be very easily modified to run in JAX: https://gist.github.com/shoyer/c0f1ddf409667650a076c058f9a17276
This is basically the equivalent of TensorFlow 2.0's tf.image.resize with method='area'.
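For integer downsampling factors, area resizing reduces to averaging non-overlapping blocks, which is short to write in JAX. This is a minimal sketch (the helper `area_downsample` is hypothetical, not the gist's code), assuming NHWC input whose height and width are divisible by `factor`; non-integer factors need fractional pixel overlaps, which it does not handle.

```python
import jax.numpy as jnp

def area_downsample(images, factor):
    """Average non-overlapping factor x factor blocks of NHWC `images`."""
    n, h, w, c = images.shape
    if h % factor or w % factor:
        raise ValueError("height and width must be divisible by factor")
    # Split each spatial axis into (blocks, factor) and average within blocks.
    blocks = images.reshape(n, h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(2, 4))

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
print(area_downsample(x, 2)[0, :, :, 0])  # values [[2.5, 4.5], [10.5, 12.5]]
```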
I suspect that most of the use-cases here would be solved by a handful of methods from tf.image.resize. Perhaps bilinear and area resizing would be enough to start.
Add "nearest" and "bilinear" upscaling to the mix, and it covers most use-cases I've seen.
As for what to verify against/emulate, AFAIK OpenCV's resize (cv2.resize) is the gold standard. See for example this issue in TF.
What is the implementation strategy for new features like this in JAX?
Can it leverage things like tf2xla's `image_resize_op`?
@rodrigob As a general rule, JAX doesn't depend on TF (only XLA), so we can't reuse TF components. That said, in general if a TF operator has an XLA lowering, you can very easily implement exactly the same lowering in JAX using the lax library which has most XLA operators. (And for an added bonus, the lowerings are usually shorter and more readable in Python.)
It seems that both `ResizeBilinearOp` and `ResizeBilinearGradOp` end up calling `ResizeUsingDilationAndConvolution`, which itself uses (to my surprise) `xla::ConvGeneralDilated`, which is indeed part of lax.
@hawkinsp If I understand correctly, you would suggest translating the C++ code to Python, up to the `lax.conv_general_dilated` call.
Yes, that's essentially what you'd want to do. It may be surprising that zooming is implemented using a convolution, but that's because TPUs are much much faster at convolution/matmul than they are at other things, so if there's any way to frame an operation as a convolution it's usually going to be the fastest way to implement it on TPUs.
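To see how a resize can be framed as a convolution, here is a 1-D sketch of 2x linear upsampling with half-pixel centers written as a single `lax.conv_general_dilated` call: `lhs_dilation=(2,)` inserts a zero between input samples (a transposed convolution), and the kernel `[0.25, 0.75, 0.75, 0.25]` holds the interpolation weights. Edge padding gives clamp-to-edge borders. This illustrates the idea only; it is not a port of TF's `ResizeUsingDilationAndConvolution`, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp

def upsample2x_linear_1d(x):
    """2x linear upsampling of a 1-D float signal with half-pixel centers."""
    # Replicate the edge samples so border outputs clamp to the edge value.
    xp = jnp.pad(x, 1, mode="edge")
    kernel = jnp.array([0.25, 0.75, 0.75, 0.25], dtype=x.dtype)
    y = jax.lax.conv_general_dilated(
        xp[None, None, :],      # lhs: (batch=1, channels=1, width)
        kernel[None, None, :],  # rhs: (out_channels=1, in_channels=1, width)
        window_strides=(1,),
        padding=[(0, 0)],
        lhs_dilation=(2,),      # insert one zero between input samples
    )
    return y[0, 0]  # length 2 * len(x)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
print(upsample2x_linear_1d(x))  # values 0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3
```

Because bilinear interpolation is separable, applying the same kernel along height and then width gives 2-D bilinear upsampling.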
This is an implementation for #1889 using @mattjj's code:
```python
import jax.numpy as jnp


def interpolate_bilinear(im, rows, cols):
    # based on http://stackoverflow.com/a/12729229
    # im: (..., nrows, ncols, channels); rows/cols: same-shaped coordinate
    # arrays, where integer coordinate k is the center of pixel k.
    col_lo = jnp.floor(cols).astype(int)
    col_hi = col_lo + 1
    row_lo = jnp.floor(rows).astype(int)
    row_hi = row_lo + 1
    nrows, ncols = im.shape[-3:-1]
    def cclip(cols): return jnp.clip(cols, 0, ncols - 1)
    def rclip(rows): return jnp.clip(rows, 0, nrows - 1)
    Ia = im[..., rclip(row_lo), cclip(col_lo), :]
    Ib = im[..., rclip(row_hi), cclip(col_lo), :]
    Ic = im[..., rclip(row_lo), cclip(col_hi), :]
    Id = im[..., rclip(row_hi), cclip(col_hi), :]
    wa = jnp.expand_dims((col_hi - cols) * (row_hi - rows), -1)
    wb = jnp.expand_dims((col_hi - cols) * (rows - row_lo), -1)
    wc = jnp.expand_dims((cols - col_lo) * (row_hi - rows), -1)
    wd = jnp.expand_dims((cols - col_lo) * (rows - row_lo), -1)
    return wa*Ia + wb*Ib + wc*Ic + wd*Id


def upsample(img, resize_rate):
    """Bilinearly resize a batch of (..., H, W, C) images by `resize_rate`."""
    nrows, ncols = img.shape[-3:-1]
    out_rows = int(resize_rate * nrows)
    out_cols = int(resize_rate * ncols)
    delta = 0.5 / resize_rate
    # Output pixel centers in input edge units, shifted by -0.5 to match the
    # pixel-center convention of interpolate_bilinear (without the shift the
    # result is offset by half an input pixel).
    rows = jnp.linspace(delta, nrows - delta, out_rows) - 0.5
    cols = jnp.linspace(delta, ncols - delta, out_cols) - 0.5
    ROWS, COLS = jnp.meshgrid(rows, cols, indexing='ij')
    img_resize_vec = interpolate_bilinear(img, ROWS.flatten(), COLS.flatten())
    return img_resize_vec.reshape(img.shape[:-3] + (out_rows, out_cols) + img.shape[-1:])
```
Here is a quick test:
```python
# In a notebook (e.g. Colab), install the dependencies first:
# !pip install -q --upgrade tfds-nightly tf-nightly
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'
# Fetch full datasets for evaluation.
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1);
# tfds.as_numpy converts them to NumPy arrays (or iterables of NumPy arrays).
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)

# First 20 training images, shape (20, 28, 28, 1).
images = mnist_data['train']['image'][:20, ...]
images_2x = upsample(images, 2)  # shape (20, 56, 56, 1)

plt.imshow(images[1, :, :, 0])
plt.figure()
plt.imshow(images_2x[1, :, :, 0])
```

I implemented a simple version of nearest-neighbor resizing based on `conv_transpose`, using only `jax.numpy` and `lax` calls; it only supports integer scales. I'm sharing it here in case it helps anyone.
```python
import jax.numpy as jnp
from jax.lax import conv_transpose


def resize_nearest(inputs, height_scale, width_scale):
    """Nearest-neighbor upsampling of NHWC `inputs` by integer scale factors."""
    batch, height, width, input_channels = inputs.shape
    # Fold channels into the batch so one 1-in/1-out kernel handles them all.
    inputs_nchw = jnp.transpose(inputs, (0, 3, 1, 2))
    flat_inputs = jnp.reshape(inputs_nchw, (-1, height, width, 1))
    # A stride-s transposed convolution with an all-ones s x s kernel copies
    # each input pixel into an s x s output block. The kernel uses the input
    # dtype so both convolution operands match.
    resize_kernel = jnp.ones((height_scale, width_scale, 1, 1), dtype=inputs.dtype)
    strides = (height_scale, width_scale)
    flat_outputs = conv_transpose(flat_inputs, resize_kernel, strides, padding="VALID")
    outputs_nchw = jnp.reshape(
        flat_outputs,
        (-1, input_channels, height_scale * height, width_scale * width),
    )
    return jnp.transpose(outputs_nchw, (0, 2, 3, 1))
```
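For integer scales, the same nearest-neighbor result can also be sketched without a convolution by repeating pixels along the spatial axes with `jnp.repeat`. Which version is faster likely depends on the backend (TPUs favor convolutions, as noted above), so treat this as an alternative to benchmark rather than a recommendation; `resize_nearest_repeat` is a hypothetical name.

```python
import jax.numpy as jnp

def resize_nearest_repeat(inputs, height_scale, width_scale):
    """Nearest-neighbor upsampling of NHWC `inputs` by integer factors."""
    out = jnp.repeat(inputs, height_scale, axis=1)  # repeat rows
    return jnp.repeat(out, width_scale, axis=2)     # repeat columns

x = jnp.arange(4.0).reshape(1, 2, 2, 1)
y = resize_nearest_repeat(x, 2, 3)
print(y.shape)  # (1, 4, 6, 1)
print(y[0, :, :, 0])
```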
I added a jax.image.resize function recently. I'm not sure how well it will perform on CPU, but it should work well on GPU and TPU, and supports bilinear, lanczos, and bicubic resizes. The outputs match PIL.
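For reference, a batched call might look like this. Note that `jax.image.resize` takes the full output shape, including the batch and channel dimensions, which are simply kept unchanged. A minimal sketch assuming NHWC float images:

```python
import jax
import jax.numpy as jnp

images = jnp.ones((20, 28, 28, 1), dtype=jnp.float32)  # e.g. MNIST-sized images
n, h, w, c = images.shape

# Upsample height and width by 2x; batch and channel sizes stay the same.
resized = jax.image.resize(images, shape=(n, 2 * h, 2 * w, c), method="bilinear")
print(resized.shape)  # (20, 56, 56, 1)
```

Other values of `method` include `"nearest"`, `"bicubic"` and `"lanczos3"`.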
Closing this issue; we can open new issues for further enhancements.