Jax: Incorrect output from jax.numpy.column_stack

Created on 4 Jun 2020 · 2 comments · Source: google/jax

Using the CPU-only version of JAX (jax 0.1.69, jaxlib 0.1.47) in a fresh conda environment.

import numpy as np
import jax.numpy as jp

x = np.array([1., 1.])
y = np.array([2., 2.])

print(np.column_stack((x, y)))   # [[1. 2.], [1. 2.]], expected
print(jp.column_stack((x, y)))   # [[1., 1., 2., 2.]], incorrect
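For reference, NumPy documents column_stack as taking a sequence of 1-D arrays and stacking them as columns of a single 2-D array. 2-D inputs are stacked as-is, as with hstack. Below is a minimal NumPy sketch of that expected behaviour. The helper names column_stack_reference and row_promoted_stack are hypothetical, and this is not JAX's actual implementation. The contrast function shows that the reported JAX output is consistent with promoting each 1-D input to a row of shape (1, n) instead of a column of shape (n, 1) before joining along axis 1.

```python
import numpy as np

def column_stack_reference(arrays):
    """Hypothetical helper sketching NumPy's documented column_stack behaviour."""
    cols = []
    for a in arrays:
        a = np.asarray(a)
        if a.ndim < 2:
            # 1-D (or scalar) inputs are turned into a column of shape (n, 1)
            a = a.reshape(-1, 1)
        cols.append(a)
    # 2-D inputs are kept as-is; all pieces are joined side by side along axis 1
    return np.concatenate(cols, axis=1)

def row_promoted_stack(arrays):
    """Hypothetical contrast: promote 1-D inputs to rows (1, n) instead of columns."""
    return np.concatenate([np.atleast_2d(a) for a in arrays], axis=1)

x = np.array([1., 1.])
y = np.array([2., 2.])
print(column_stack_reference((x, y)))  # [[1. 2.]
                                       #  [1. 2.]]
print(row_promoted_stack((x, y)))      # [[1. 1. 2. 2.]]
```

The first result matches np.column_stack((x, y)). The second reproduces the shape (1, 4) reported for jp.column_stack.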

Comments

Hmm, seems close enough up to numerical error...

Just kidding! Thanks for spotting the issue. Our tests must be lacking here.
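The maintainer notes that the tests likely missed this case. Below is a minimal sketch of a regression check that compares a candidate column_stack against np.column_stack on the inputs from the report and on a couple of other shapes. The helper name check_column_stack and the chosen cases are illustrative assumptions. This is not the JAX project's actual test suite.

```python
import numpy as np

def check_column_stack(impl):
    """Compare a candidate column_stack with np.column_stack (hypothetical helper).

    `impl` is any callable that takes a sequence of arrays, like column_stack.
    """
    cases = [
        (np.array([1., 1.]), np.array([2., 2.])),         # the two 1-D vectors from the report
        (np.arange(3.), np.arange(3.) * 10, np.ones(3)),   # three 1-D columns
        (np.arange(6.).reshape(3, 2), np.arange(3.)),      # a 2-D block plus a 1-D column
    ]
    for arrays in cases:
        expected = np.column_stack(arrays)
        got = np.asarray(impl(arrays))
        # Check the shape first: the reported bug gives (1, 4) instead of (2, 2)
        assert got.shape == expected.shape, (got.shape, expected.shape)
        np.testing.assert_allclose(got, expected)
```

With JAX installed, one would call check_column_stack(jax.numpy.column_stack). Based on the output in the report, the affected version would fail the shape check on the first case.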

