PySyft: Removing the head torch tensor wrapper

Created on 26 Nov 2018 · 14 Comments · Source: OpenMined/PySyft

This is a sub-issue of #1657 questioning the design of tensor chains.

If I understand correctly, in the "chain" structure of torch1 PySyft, local tensor chains will not loop, which is great, but we will have an empty wrapper at the head.

For example, if I want to log a tensor's operations locally with a LogTensor, I'll have:

[Empty Torch Tensor]
  v
[LogTensor]
  v
[Torch Tensor]
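A minimal, torch-free sketch of the chain above, to show what the empty head actually does. All class names are hypothetical stand-ins (NativeTensor plays the role of a torch tensor, HeadWrapper the empty head); this illustrates forwarding through `.child`, not PySyft's real implementation.

```python
class NativeTensor:
    """Hypothetical stand-in for a torch tensor holding plain numbers."""
    def __init__(self, data):
        self.data = list(data)

    def add(self, other):
        return NativeTensor(a + b for a, b in zip(self.data, other.data))


class LogTensor:
    """Records each operation, then forwards it to its child."""
    def __init__(self, child):
        self.child = child
        self.log = []

    def add(self, other):
        self.log.append("add")
        return self.child.add(other)


class HeadWrapper(NativeTensor):
    """The 'empty' head: a native tensor type with no data of its own."""
    def __init__(self, child):
        super().__init__([])  # an empty payload is still allocated
        self.child = child

    def add(self, other):
        # The head does no work itself; every command goes down the chain.
        return self.child.add(other)


log = LogTensor(NativeTensor([1, 2, 3]))
x = HeadWrapper(log)  # [Empty Torch Tensor] -> [LogTensor] -> [Torch Tensor]
y = x.add(NativeTensor([10, 20, 30]))
print(y.data, log.log)  # [11, 22, 33] ['add']
```

Every call pays for one extra hop through the head, and every chain pays for one extra (empty) allocation, which is the overhead discussed below.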

Similarly, when I send a tensor (and get a pointer back), I'll have:

[Empty Torch Tensor]
  v
[PointerTensor]

This means that we're handling a bunch of empty torch tensors as wrappers. This not only adds complexity, since it makes the chain longer, but also adds overhead from allocating those empty tensors (one could argue for a store of pre-allocated wrappers, but at the expense of extra code complexity).

Is this empty top wrapper really necessary? It seems like simply removing it wouldn't change the code's behavior. Additionally, exposing a non-torch tensor as the head object to the user is unlikely to be a problem, as long as the user _thinks_ that it's one.

Help Wanted

All 14 comments

The main reason for having an empty torch tensor at the head is compatibility with the built-in torch modules. They break if you pass them a non-torch tensor. Ideally, if we could subclass the torch tensor, we would not need these head tensors, but that functionality is not yet available in PyTorch.
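To illustrate the compatibility argument, here is a minimal sketch assuming a built-in module rejects its input through a type check. `NativeTensor`, `EmptyHead` and `builtin_layer` are hypothetical stand-ins, not torch or PySyft APIs; real torch modules fail in their own ways, but the effect is similar.

```python
class NativeTensor:
    """Hypothetical stand-in for torch.Tensor."""
    def __init__(self, data):
        self.data = list(data)


class PointerTensor:
    """Syft-style object that is NOT a NativeTensor subclass."""
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location


class EmptyHead(NativeTensor):
    """Empty native tensor whose only job is to pass type checks."""
    def __init__(self, child):
        super().__init__([])
        self.child = child


def builtin_layer(x):
    """Mimics a built-in module that rejects non-tensor inputs."""
    if not isinstance(x, NativeTensor):
        raise TypeError(f"expected NativeTensor, got {type(x).__name__}")
    return "accepted"


ptr = PointerTensor("bob", 1)
try:
    builtin_layer(ptr)  # a bare pointer fails the type check
except TypeError as err:
    print(err)
print(builtin_layer(EmptyHead(ptr)))  # the wrapped pointer gets through
```

Note that the head only gets past the type check: the module still has no data to compute on, which is why the wrapped calls must also be hooked.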

Oh I see. Do you have an example of a torch module that would break? Usually we hook them as well.

Unfortunately, I don't know one off the top of my head. I'll see if I can set up a minimal failing example.

This implies that when you call x.send(), you need to _transform_ a torch tensor into a PointerTensor, whereas previously we emptied the torch tensor and added a PointerTensor as its child.

However, this is not possible as far as I know.

One solution could be _not_ to modify the tensor in place, at the expense of elegance:

x_ptr = x.send(bob)

Or even

x = x.send(bob)

In that case the original x tensor would be garbage collected, which is far cheaper than the old x.native_set_() method we previously used to turn the tensor into an empty head.
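A minimal sketch of the non-mutating `x = x.send(bob)` style, assuming `send()` copies the payload to a worker and returns a brand-new pointer object instead of rewriting the tensor in place. `Worker`, `LocalTensor` and `PointerTensor` are hypothetical stand-ins; `weakref` is used only to observe when the original is freed.

```python
import weakref


class Worker:
    """Hypothetical remote worker: just a named object store."""
    def __init__(self, name):
        self.name = name
        self.objects = {}


class PointerTensor:
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location


class LocalTensor:
    """Stand-in for a torch tensor with a non-mutating send()."""
    def __init__(self, data):
        self.data = list(data)

    def send(self, worker):
        # Copy the payload to the worker and return a NEW pointer object;
        # self is left untouched (no native_set_()-style mutation).
        obj_id = len(worker.objects)
        worker.objects[obj_id] = list(self.data)
        return PointerTensor(worker, obj_id)


bob = Worker("bob")
x = LocalTensor([1, 2, 3])
original = weakref.ref(x)  # does not keep the tensor alive
x = x.send(bob)            # rebinding drops the last strong reference
print(type(x).__name__, original() is None, bob.objects)
```

In CPython, reference counting frees the original as soon as the name is rebound; other interpreters may defer collection, but either way no in-place type change is needed.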

I would push for this solution because 1) it avoids having a method that changes the nature of an object, which we usually try to avoid in OOP, and 2) this reassignment syntax is used extensively in torch NN module definitions (e.g. x = layer(x) in a forward() method).
However, it is not as easy to read as x.send(bob).

Any thoughts?

One more thing: without empty heads, we no longer need to hook torch tensor methods, which is great for efficiency!

I'm just wondering whether the tensor should be the one holding the send method. Wouldn't a worker be better suited to send objects to another target worker? I'd like to understand why you chose this design for the workflow. I would expect something like this:

my_worker = Worker()
tensor = torch.Tensor()
my_worker.register_worker(remote_worker_address)

tensor_ptr = my_worker.send(tensor)
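The worker-centric workflow proposed above could look roughly like this minimal in-process sketch. The `Worker` class and its methods are hypothetical; registration takes a worker object rather than a network address, and `send()` takes an explicit target name (an assumption, since a worker may know several other workers).

```python
import itertools


class PointerTensor:
    """Points at an object stored on another worker."""
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location


class Worker:
    """Hypothetical worker that owns the send logic instead of the tensor."""
    _ids = itertools.count()  # ids shared across all workers in this sketch

    def __init__(self, name):
        self.name = name
        self.objects = {}
        self.known_workers = {}

    def register_worker(self, worker):
        # In-process stand-in for registering a remote worker's address.
        self.known_workers[worker.name] = worker

    def send(self, obj, target_name):
        # Store obj on the target and hand back a pointer to it.
        target = self.known_workers[target_name]
        obj_id = next(Worker._ids)
        target.objects[obj_id] = obj
        return PointerTensor(target, obj_id)


my_worker, bob = Worker("me"), Worker("bob")
my_worker.register_worker(bob)
tensor_ptr = my_worker.send([1.0, 2.0], "bob")
print(tensor_ptr.location.name, bob.objects)
```

With this design the tensor class needs no send method at all, so nothing has to mutate or wrap it at send time.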

Either way, I imagine that having a wrapper at the top of the chain, which replaces the tensor with a pointer once send is called, might be the most fail-proof method. Otherwise, I could imagine the following scenario:

x = torch.Tensor()
x_ptr = x.send(bob)

# modifying and using a tensor that has changed its owner
x = x*x
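A minimal sketch of the "swap on send" wrapper, assuming a head object whose child is replaced by a pointer when `send()` is called, so that a later `x = x*x` is forwarded to the remote worker instead of using stale local data. `Head`, `Worker`, `NativeTensor` and `PointerTensor` are hypothetical stand-ins; serialization and networking are omitted.

```python
class Worker:
    """Hypothetical remote worker: a named object store."""
    def __init__(self, name):
        self.name = name
        self.objects = {}


class NativeTensor:
    def __init__(self, data):
        self.data = list(data)

    def mul(self, other):
        return NativeTensor(a * b for a, b in zip(self.data, other.data))


class PointerTensor:
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location

    def mul(self, other):
        # Run the command on the worker holding both operands and
        # return a pointer to the result stored there.
        store = self.location.objects
        a, b = store[self.id_at_location], store[other.id_at_location]
        new_id = len(store)
        store[new_id] = [p * q for p, q in zip(a, b)]
        return PointerTensor(self.location, new_id)


class Head:
    """Wrapper whose child is swapped from native data to a pointer on send()."""
    def __init__(self, child):
        self.child = child

    def send(self, worker):
        obj_id = len(worker.objects)
        worker.objects[obj_id] = self.child.data
        self.child = PointerTensor(worker, obj_id)  # the swap
        return self

    def __mul__(self, other):
        return Head(self.child.mul(other.child))


bob = Worker("bob")
x = Head(NativeTensor([1, 2, 3]))
x_ptr = x.send(bob)
x = x * x  # forwarded to bob, not computed on stale local data
print(type(x.child).__name__, bob.objects)
```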

I do agree that the worker should be the one sending the tensor. Makes sense to me!

Using x after it has been sent, as your example shows, is indeed problematic. But the problem comes from the fact that today "send" means "transfer", while it could also mean "send but keep a local original"; the original can then be destroyed by reassigning the name, as in x = x.send(bob). Hmm, I'm not sure what would be best.

Yeah, you're totally correct. I think what we are trying to achieve is a transfer method.

If we want to keep a local original of the tensor while it is stored on the remote worker, we would need to keep the data consistent across all versions of the tensor. This would force us to constantly update the tensor, which can be a big overhead.

I see the following possible scenarios:

  1. Let the user make sure that they are not using tensors that were sent. For example, we could add a sent flag to the tensor; once it is set to True, the tensor prints warnings whenever it is used.
  2. (Current situation) Have a syft wrapper at the top, which swaps the tensor with a pointer once it's sent.
  3. Have a PointerTensor at the top as a wrapper. The pointer would act like the current [Empty Torch Tensor], but would point either to the local torch tensor or to the remote tensor.
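Option 1 can be sketched in a few lines, assuming a hypothetical `sent` flag and using Python's standard `warnings` module. `LocalTensor` is a stand-in, and a plain list plays the role of the remote worker's store.

```python
import warnings


class LocalTensor:
    """Option 1 sketch: a tensor that warns when used after being sent."""
    def __init__(self, data):
        self.data = list(data)
        self.sent = False  # hypothetical flag set by send()

    def _warn_if_sent(self):
        if self.sent:
            warnings.warn("this tensor was sent to another worker; "
                          "its local data may be stale",
                          RuntimeWarning, stacklevel=3)

    def send(self, remote_store):
        # remote_store: a plain list standing in for the remote worker.
        remote_store.append(list(self.data))
        self.sent = True
        return len(remote_store) - 1  # id of the object on the remote side

    def __mul__(self, other):
        self._warn_if_sent()
        return LocalTensor(a * b for a, b in zip(self.data, other.data))


remote_store = []
x = LocalTensor([1, 2, 3])
x.send(remote_store)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = x * x  # still computed locally, but the user is warned
print(y.data, [str(w.message) for w in caught])
```

The computation still runs on the local copy; the flag only alerts the user, which is why options 2 and 3 are considered more fail-proof.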

I'd favor solution 2 and 3 over 1. Thoughts?

Hey everyone,

Quick question: is it possible for the worker to implement the send operation and then have a garbage collection operation to clean the "transferred" tensor?

I think some functionality like this is addressed in https://github.com/OpenMined/PySyft/issues/1701. Do you have a specific plan for the implementation already? Otherwise, I would guess that your suggestion can work :D

Hey Ogofo, I don't have a specific plan, but I'll try to take a closer look at this part of the code to make sure it makes sense. I'll post updates in this thread.

> Oh I see. Do you have an example of torch module which would break? Usually we also hook them as well

All torch modules break. I suppose the most obvious are the layer types (every layer type, every loss function), as well as any API calls involving other torch tensors.

The only costs of having the top-level object are:

1) initializing the tensors (more init cost)
2) having to forward commands to .child tensors

I think we can minimize (1) by using a pool of wrappers (there's nothing unique about any one of them, so they should be easy to reuse), and (2) is quite cheap as well.

> All torch modules break. I suppose the most obvious are all the layer types (every layer type, every loss function, as well as any API calls with other torch tensors)

They already break if you give them empty wrappers standing in for pointers. That's why we had to hook them: analyse the args/kwargs to search for pointers and, if there are any, send the command to the appropriate location.
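The hooking strategy described above can be sketched as a decorator-style wrapper. This is a minimal illustration: `hook`, `send_command` and `PointerTensor` are hypothetical, the transport only records the command, and only top-level arguments are scanned (a real hook would also need to look inside nested lists and tuples).

```python
import functools


class PointerTensor:
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location


def find_pointers(args, kwargs):
    """Return every top-level PointerTensor among args and kwargs."""
    values = list(args) + list(kwargs.values())
    return [v for v in values if isinstance(v, PointerTensor)]


def hook(native_fn, send_command):
    """Run native_fn locally unless one of its arguments is a pointer."""
    @functools.wraps(native_fn)
    def hooked(*args, **kwargs):
        pointers = find_pointers(args, kwargs)
        if pointers:
            # Route the command to the worker that owns the first pointer.
            location = pointers[0].location
            return send_command(location, native_fn.__name__, args, kwargs)
        return native_fn(*args, **kwargs)
    return hooked


sent = []


def send_command(location, name, args, kwargs):
    # Hypothetical transport: record the command instead of sending it.
    sent.append((location, name))
    return PointerTensor(location, len(sent))


def relu(xs):
    return [max(0, v) for v in xs]


relu = hook(relu, send_command)
local = relu([-1, 2])                   # no pointers: runs locally
remote = relu(PointerTensor("bob", 0))  # pointer found: forwarded to "bob"
print(local, sent)
```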
