Using the CPU-only version of JAX (jax 0.1.69, jaxlib 0.1.47) in a fresh conda environment.
```python
import numpy as np
import jax.numpy as jp

x = np.array([1., 1.])
y = np.array([2., 2.])
print(np.column_stack((x, y)))  # [[1. 2.], [1. 2.]], expected
print(jp.column_stack((x, y)))  # [[1. 1. 2. 2.]], incorrect
```
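For reference, `np.column_stack` turns each 1-D input into a column of shape `(N, 1)` and stacks 2-D inputs as-is, then concatenates along axis 1. The NumPy-only sketch below spells that rule out. Both helper names (`column_stack_reference`, `row_promotion_stack`) are hypothetical and not part of NumPy or JAX. The second helper shows that the incorrect JAX output above is exactly what you get if each 1-D input is promoted to a *row* of shape `(1, N)` with `np.atleast_2d` instead. That is only a guess at the cause, not a confirmed diagnosis.

```python
import numpy as np

def column_stack_reference(arrays):
    """Hypothetical reference for np.column_stack semantics."""
    cols = []
    for a in arrays:
        a = np.asarray(a)
        if a.ndim < 2:
            # 1-D (or 0-D) input becomes a single column of shape (N, 1)
            a = a.reshape(-1, 1)
        cols.append(a)
    # Columns are placed side by side along axis 1
    return np.concatenate(cols, axis=1)

def row_promotion_stack(arrays):
    """Hypothetical buggy variant: atleast_2d turns 1-D inputs into rows (1, N)."""
    return np.concatenate([np.atleast_2d(a) for a in arrays], axis=1)

x = np.array([1., 1.])
y = np.array([2., 2.])
print(column_stack_reference((x, y)))  # 2x2: both rows are [1. 2.]
print(row_promotion_stack((x, y)))     # [[1. 1. 2. 2.]]
```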
Hmm, seems close enough up to numerical error...
Just kidding! Thanks for spotting the issue. Our tests must be lacking here.
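A regression test along these lines would have caught it: compare the JAX result against NumPy over a few input shapes. This is a minimal sketch, not JAX's actual test suite (which has its own harness). The helper `check_column_stack` and the chosen shapes are illustrative assumptions.

```python
import numpy as np

# Input combinations to cover: pure 1-D, pure 2-D, and a mix of the two
CASES = [
    (np.array([1., 1.]), np.array([2., 2.])),
    (np.arange(6.).reshape(3, 2), np.arange(3.).reshape(3, 1)),
    (np.arange(6.).reshape(3, 2), np.arange(3.)),
]

def check_column_stack(impl):
    """Hypothetical helper: assert impl agrees with np.column_stack on CASES."""
    for arrays in CASES:
        expected = np.column_stack(arrays)
        actual = np.asarray(impl(arrays))
        # The shape check alone catches the (1, 4)-vs-(2, 2) bug in this report
        assert actual.shape == expected.shape, (actual.shape, expected.shape)
        np.testing.assert_allclose(actual, expected)

try:
    import jax.numpy as jnp
except ImportError:  # JAX not installed; skip the JAX comparison
    jnp = None

if jnp is not None:
    check_column_stack(jnp.column_stack)
```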
Tests? Who needs tests? :grin: https://github.com/google/jax/search?q=column_stack