Jax: Should ShapeDtypeStruct.size exist?

Created on 6 Feb 2020 · 5 Comments · Source: google/jax

ShapeDtypeStruct currently has dtype and shape. We could add numpy's size property via

import numpy as np

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def size(self):
    # Total number of elements; an empty shape () yields 1, as for a scalar.
    return np.prod(self.shape, dtype=int)

Do you want this patch?

enhancement

All 5 comments

+1, maybe ndim too? Also it would be nice to have __repr__ :smile: (maybe we can make this a NamedTuple?)

NamedTuples are bad:

shape, dtype = struct  # WAT

But yes to __repr__; I hit the lack of that just a moment ago. :)
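To make the unpacking complaint concrete, here is a minimal sketch. `StructTuple` is a hypothetical NamedTuple variant invented for illustration (it is not a JAX class), and the dtype is written as a plain string so the snippet has no dependencies.

```python
from typing import NamedTuple, Tuple


# Hypothetical NamedTuple variant, for illustration only (not JAX's class).
class StructTuple(NamedTuple):
    shape: Tuple[int, ...]
    dtype: str


struct = StructTuple((2, 3), "float32")

# A NamedTuple is still a tuple, so it unpacks, indexes, and has a length.
shape, dtype = struct
field_count = len(struct)   # number of fields, not the leading array dimension
first_field = struct[0]     # the shape tuple itself
```

Note that `len(struct)` counts fields here, which would clash with a numpy-style `__len__` that returns the leading dimension.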

Yeah, I recently learned how bad namedtuples are! What confused me was this kind of behavior:

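from collections import namedtuple
import numpy as onp
Point = namedtuple('Point', ['x', 'y'])
hash(Point((2, 3), onp.float32)) == hash(ShapeDtypeStruct((2, 3), onp.float32))  # True if ShapeDtypeStruct were a namedtuple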

They're really meant to be just sugar on tuples, without type tags. That's not super relevant, but I thought I'd complain about it at every opportunity.
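A self-contained sketch of the "no type tags" behavior described above. Both namedtuple types are hypothetical stand-ins, and the dtype is a plain string rather than a numpy dtype so that nothing beyond the standard library is needed.

```python
from collections import namedtuple

# Two unrelated namedtuple types (both hypothetical, for illustration).
Point = namedtuple("Point", ["x", "y"])
Struct = namedtuple("Struct", ["shape", "dtype"])

p = Point((2, 3), "float32")
s = Struct((2, 3), "float32")

# Equality and hashing are inherited from tuple, so the type name is ignored.
same_hash = hash(p) == hash(s)
same_value = p == s
equals_plain_tuple = p == ((2, 3), "float32")

# Used as dict keys, the two values collapse into a single entry.
table = {p: "point"}
table[s] = "struct"
entry_count = len(table)
```

The practical risk is the last step: a cache keyed on such values cannot distinguish a `Point` from a `Struct` with the same contents.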

ShapeDtypeStruct exists because I didn't want to expose much API surface here. ShapedArray does everything we want (and more), but I didn't want to surface that as part of the API.

So we could:

  1. keep ShapeDtypeStruct super minimal, no convenience methods/properties
  2. add these select convenience methods/properties
  3. just use ShapedArray here

I don't have strong feelings between 1 and 2, and I suspect that you both as users of this API might be best able to decide between them. I am very cautious about 3 because I don't like exposing internals, even if it means some redundant code.

WDYT?

Understood re namedtuple. While we wait for dataclasses, is it worth taking a dep on attrs to make these POD classes more consistent and easier to read (e.g. generated slots, eq, hash, immutability, etc.)?

I vote for (2) and agree with not exposing the internal ShapedArray. I think a complete implementation would have: shape, dtype, ndim, size, __len__ and of course __eq__ and __hash__.

Rather than using attrs, I would rather require Python 3.6+ with the dataclasses backport.
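For comparison, a sketch of what the dataclasses route could look like. This is an assumption about how it might be written, not a proposed patch: the field types and the `frozen=True` choice are illustrative, and the dtype is again a plain string in the usage line.

```python
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import Any, Tuple


@dataclass(frozen=True)
class ShapeDtypeStruct:
    """Illustrative sketch; not the actual JAX class."""
    shape: Tuple[int, ...]
    dtype: Any

    @property
    def size(self) -> int:
        # Stand-in for np.prod(self.shape, dtype=int).
        return reduce(mul, self.shape, 1)

    @property
    def ndim(self) -> int:
        return len(self.shape)


struct = ShapeDtypeStruct((2, 3), "float32")
```

With `frozen=True` and the default `eq=True`, the decorator generates `__repr__`, `__eq__`, and `__hash__`, instances reject attribute assignment, and the generated `__eq__` only matches instances of the same class. Unlike a namedtuple, the instance is not iterable, so `shape, dtype = struct` raises `TypeError`.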
