Tfjs: tf.oneHot() should support string tensors

Created on 19 Jan 2019 · 5 Comments · Source: tensorflow/tfjs

TensorFlow.js version

0.14.2

Describe the problem or feature request

tf.oneHot() should support string tensors as indices.
Since TensorFlow supports this (https://www.tensorflow.org/api_docs/python/tf/one_hot), TensorFlow.js should support it as well for consistency.

core feature

All 5 comments

I'm not sure what supporting string tensors in tf.oneHot means. One-hot encoding in TensorFlow seems to accept only integer types:

TypeError: Value passed to parameter 'indices' has DataType string not in list of allowed values: uint8, int32, int64

Could you give an example of the problem? I'd like to work on this if I get a chance.

For example, tf.oneHot(tf.tensor1d(['a', 'b', 'c', 'a', 'a'])).print() should give:

    [[1, 0, 0],
     [0, 1, 0],
     [0, 0, 1],
     [1, 0, 0],
     [1, 0, 0]]

@Firenze11 Thanks! I'll try.

@Firenze11 This expression makes sense to me.

tf.oneHot(tf.tensor1d(['a', 'b', 'c', 'a', 'a'])).print()

But TensorFlow core does not support such input, so I'm not sure we should support string input for the oneHot op.

>>> tf.one_hot(['a', 'b', 'c', 'a', 'a'], 3).eval()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sasakikai/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py", line 2445, in one_hot
    name)
  File "/Users/sasakikai/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 4724, in one_hot
    off_value=off_value, axis=axis, name=name)
  File "/Users/sasakikai/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "/Users/sasakikai/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'indices' has DataType string not in list of allowed values: uint8, int32, int64
>>> tf.__version__
'1.12.0'

Since we want to stay aligned with TF Python, we can't add that support: TF doesn't accept string as the dtype of the indices param.
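
For anyone who still needs the behaviour requested above, one possible workaround is to map the strings to integer indices yourself before one-hot encoding. The following is only a sketch, not part of tfjs: `encodeLabels` and `oneHotRows` are hypothetical helper names, and it assumes class indices are assigned in order of first appearance. The one-hot step is written in plain JavaScript so the snippet runs without tfjs installed.

```javascript
// Hypothetical helper (not a tfjs API): map each distinct string label to an
// integer index, assigned in order of first appearance.
function encodeLabels(labels) {
  const vocab = new Map();
  const indices = labels.map((label) => {
    if (!vocab.has(label)) {
      vocab.set(label, vocab.size);
    }
    return vocab.get(label);
  });
  return { indices, depth: vocab.size, vocab };
}

// Hypothetical plain-JavaScript one-hot, used only to keep the sketch
// independent of tfjs.
function oneHotRows(indices, depth) {
  return indices.map((i) =>
    Array.from({ length: depth }, (_, j) => (j === i ? 1 : 0))
  );
}

const { indices, depth } = encodeLabels(['a', 'b', 'c', 'a', 'a']);
const rows = oneHotRows(indices, depth);
console.log(rows);
```

With tfjs loaded, the same integer indices could instead be passed to the existing op, for example `tf.oneHot(tf.tensor1d(indices, 'int32'), depth)`, which keeps the string handling outside the library.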
