I have been using cv2 to histogram equalize my images. I recently found that PIL has a function called ImageOps.equalize(image, mask=None) which does the same thing. This transform has been really useful to me. Since this transform is implemented in PIL, which is a supported backend, I was wondering if it would be a good addition to torchvision.transforms.
@fmassa , I would like to take this up if we want to add this.
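For context, here is roughly how I use the PIL op today (just a minimal sketch; the file path is a placeholder):

from PIL import Image, ImageOps

img = Image.open("example.jpg").convert("RGB")   # placeholder path
equalized = ImageOps.equalize(img, mask=None)    # equalizes the histogram of each band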
Looks like this histogram matching feature has already been requested before in #598 and was implemented in #796 but hasn't been merged.
I've implemented a much simpler version identical to the one in PIL.ImageOps.equalize.
import torch

def torch_equalize(image):
    """Implements the Equalize function from PIL using PyTorch ops, based on:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352"""
    def scale_channel(im, c):
        """Scale the data in the channel to implement equalize."""
        im = im[:, :, c]
        # Compute the histogram of the image channel.
        histo = torch.histc(im, bins=256, min=0, max=255)  # .type(torch.int32)
        # For the purposes of computing the step, filter out the zeros.
        nonzero_histo = torch.reshape(histo[histo != 0], [-1])
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

        def build_lut(histo, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalizing by step.
            lut = (torch.cumsum(histo, 0) + (step // 2)) // step
            # Shift lut, prepending with 0.
            lut = torch.cat([torch.zeros(1), lut[:-1]])
            # Clip the counts to be in range. This is done
            # in the C code for image.point.
            return torch.clamp(lut, 0, 255)

        # If step is zero, return the original image. Otherwise, build
        # lut from the full histogram and step and then index from it.
        if step == 0:
            result = im
        else:
            # Can't index using a 2d index. Have to flatten and then reshape.
            result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
            result = result.reshape_as(im)
        return result.type(torch.uint8)

    # Assumes RGB for now. Scales each channel independently
    # and then stacks the result.
    s1 = scale_channel(image, 0)
    s2 = scale_channel(image, 1)
    s3 = scale_channel(image, 2)
    image = torch.stack([s1, s2, s3], 2)
    return image
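For reference, here is how I would expect it to be called (a minimal sketch; the file path is a placeholder). Since scale_channel indexes channels last and torch.histc needs a floating-point input, the image should be an H x W x 3 float tensor with values in [0, 255]:

import numpy as np
import torch
from PIL import Image, ImageOps

img = Image.open("example.jpg").convert("RGB")      # placeholder path
tensor = torch.from_numpy(np.array(img)).float()    # H x W x 3, values in [0, 255]
out = torch_equalize(tensor)                        # uint8 tensor, H x W x 3
reference = ImageOps.equalize(img)                  # PIL's result, for comparison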