Stable-baselines: Maybe there is one problem in implementing the class PrioritizedReplayBuffer

Created on 13 Jul 2020 · 9 Comments · Source: hill-a/stable-baselines

In the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/buffers.py, line 206 computes the total priority with
`total = self._it_sum.sum(0, len(self._storage) - 1)`,
i.e. it passes `len(self._storage) - 1` as the `end` argument of `self._it_sum.sum`:

    def _sample_proportional(self, batch_size):
        mass = []
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

However, in the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/segment_tree.py, the function `reduce`, which `self._it_sum.sum` calls, also subtracts 1 from `end` (line 75, `end -= 1`):

    def reduce(self, start=0, end=None):
        """
        Returns result of applying `self.operation`
        to a contiguous subsequence of the array.
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
        :param start: (int) beginning of the subsequence
        :param end: (int) end of the subsequences
        :return: (Any) result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

_Is `end` being decremented by 1 twice?_

I verified this with the following code:

    import numpy as np
    from stable_baselines.common.buffers import PrioritizedReplayBuffer
    buffer = PrioritizedReplayBuffer(100, 0.6)  # size=100, alpha=0.6
    x = np.array([1.])
    for _ in range(10):
        buffer.add(x, x, x, x, x)  # obs_t, action, reward, obs_tp1, done
    print(buffer._it_sum.sum(0, len(buffer._storage) - 1))  # result: 9.0
    print(buffer._it_sum.sum(0, len(buffer._storage)))      # result: 10.0

If I change `len(buffer._storage) - 1` to `len(buffer._storage)`, I get the correct result: I added 10 items to the buffer, each with priority 1.0, so the total priority should be 10.
If I have misunderstood the code, please let me know.
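The same 9.0 vs 10.0 result can be reproduced without Stable Baselines installed. `MiniSumTree` below is a hypothetical stand-in written only for illustration: it keeps priorities in a flat NumPy array instead of a segment tree and copies just the `end` handling of `reduce` quoted above, in which `end` is exclusive. Dropping the `- 1` at the call site is one possible fix; removing the decrement inside `reduce` would be another, so treat this as a sketch of the mechanism, not the library's patch.

```python
import numpy as np


class MiniSumTree:
    """Hypothetical stand-in that mimics only the `end` handling of SegmentTree.reduce."""

    def __init__(self, capacity):
        self._capacity = capacity
        self._values = np.zeros(capacity)

    def __setitem__(self, idx, val):
        self._values[idx] = val

    def sum(self, start=0, end=None):
        # Same normalisation as SegmentTree.reduce: `end` is treated as exclusive.
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        # After the decrement, [start, end] is an inclusive range.
        return float(self._values[start:end + 1].sum())


tree = MiniSumTree(100)
n_stored = 10
for i in range(n_stored):
    tree[i] = 1.0  # every stored item has priority 1.0, as in the example above

buggy_total = tree.sum(0, n_stored - 1)  # what _sample_proportional passes today
fixed_total = tree.sum(0, n_stored)      # `end` is already exclusive, so no "- 1"
print(buggy_total, fixed_total)  # 9.0 10.0
```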

Label: bug

All 9 comments

I do not quite understand what the issue is. Could you format your example code with triple backticks (``` like this ```)? Also, your suggested change looks identical to the original code.

Sorry about that. I have edited my question.

Thank you! Now I see the issue. Yes, I think there is one extra `end -= 1`, and because of it the last added item (more precisely, the item at index `len(self._storage) - 1`) is never included in the sampling. This can be tested with:

    import numpy as np
    from stable_baselines.common.buffers import PrioritizedReplayBuffer
    buffer = PrioritizedReplayBuffer(100, 0.6)
    for i in range(10):
        x = np.array([i])  # tag each transition with its insertion index
        buffer.add(x, x, x, x, x)
    print(buffer.sample(10, beta=0.5))
    print(buffer.sample(10, beta=0.5))
    # ...however often you repeat this, state 9 never appears in the sampled experiences (neither as state nor as next state)

A PR to fix this would be welcome :). Given how the functions are documented, I am not sure exactly where the extra "-1" should be removed. An update to the tests for this change would also be nice.
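Why the last item can never be drawn is easiest to see with a plain NumPy version of proportional sampling. This sketch swaps the segment tree's `find_prefixsum_idx` for `np.cumsum` plus `np.searchsorted` (not the library code), and `sample_proportional` is a hypothetical helper. Masses are drawn uniformly from `[0, total)`, so if `total` omits the last priority, no mass can land in the last item's slice of the prefix sum. The assertions at the end could serve as a starting point for the regression test mentioned above.

```python
import numpy as np


def sample_proportional(priorities, batch_size, total, rng):
    """Hypothetical helper: draw indices with probability proportional to priority."""
    prefix = np.cumsum(priorities)
    mass = rng.random(batch_size) * total  # uniform in [0, total)
    # First index whose cumulative priority exceeds `mass` (plays the role of find_prefixsum_idx).
    return np.searchsorted(prefix, mass, side="right")


rng = np.random.default_rng(0)
priorities = np.ones(10)  # ten stored items with equal priority

buggy = sample_proportional(priorities, 10_000, priorities[:-1].sum(), rng)
fixed = sample_proportional(priorities, 10_000, priorities.sum(), rng)
print(9 in buggy, 9 in fixed)  # False True
```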

By the way, the code that computes the importance-sampling weights in `PrioritizedReplayBuffer.sample` can be simplified.

Change

    p_min = self._it_min.min() / self._it_sum.sum()
    max_weight = (p_min * len(self._storage)) ** (-beta)
    p_sample = self._it_sum[idxes] / self._it_sum.sum()
    weights = (p_sample * len(self._storage)) ** (-beta) / max_weight

to

    weights = (self._it_sum[idxes] / self._it_min.min()) ** (-beta)

This follows from a short derivation: `len(self._storage)` and `self._it_sum.sum()` appear in both the numerator and the denominator, so they cancel.
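Written out, with N = `len(self._storage)`, S = `self._it_sum.sum()`, p_i = `self._it_sum[i]` and m = `self._it_min.min()`, the original code computes the weight of sampled index i as follows, and the common factor N/S drops out:

```latex
w_i = \frac{\left(N \frac{p_i}{S}\right)^{-\beta}}{\left(N \frac{m}{S}\right)^{-\beta}}
    = \left(\frac{N p_i / S}{N m / S}\right)^{-\beta}
    = \left(\frac{p_i}{m}\right)^{-\beta}
```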

I also verified this numerically with the following code:

    import numpy as np
    from stable_baselines.common.buffers import PrioritizedReplayBuffer

    buffer = PrioritizedReplayBuffer(100, 0.6)
    for i in range(10):
        x = np.array([i])
        buffer.add(x, x, x, x, x)
    # update priorities to [0.05 0.1 0.15 0.2 0.25 0.3 0.35 0.4 0.45 0.5]
    buffer.update_priorities(np.arange(10), np.linspace(0.05, 0.5, 10))
    data = buffer.sample(10, beta=0.5)
    weights1 = data[-2]  # weights from the current implementation
    idxes = data[-1]     # sampled indices
    weights2 = (buffer._it_sum[idxes] / buffer._it_min.min()) ** (-0.5)  # simplified formula
    print(weights1 - weights2)

All differences are 0.
The original code follows the formula in the paper more closely, but I think the simplified code is much faster.
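The same check can be run without Stable Baselines installed. This sketch assumes, as in the buffer, that the trees hold priorities raised to `alpha`; the arrays and the random indices are made up for illustration, and the variable names only mirror the snippet above.

```python
import numpy as np

rng = np.random.default_rng(42)
alpha, beta = 0.6, 0.5
raw_priorities = np.linspace(0.05, 0.5, 10)
stored = raw_priorities ** alpha     # what the sum/min trees would hold
n = len(stored)                      # plays the role of len(self._storage)
idxes = rng.integers(0, n, size=10)  # arbitrary sampled indices

# Original formulation (closer to the paper)
p_min = stored.min() / stored.sum()
max_weight = (p_min * n) ** (-beta)
p_sample = stored[idxes] / stored.sum()
weights1 = (p_sample * n) ** (-beta) / max_weight

# Simplified formulation: n and the total sum cancel out
weights2 = (stored[idxes] / stored.min()) ** (-beta)

print(np.allclose(weights1, weights2))  # True
```

Since the two expressions are mathematically identical, any remaining difference can only be floating-point rounding, which is why comparing with `np.allclose` is safer than expecting exact zeros.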

Nice catch! It indeed checks out. I do not know how much faster it would be (the bottleneck is the segment-tree summation), but it is more compact and easier to understand at a glance.

Feel free to make a PR that includes these two changes :)

I agree with you. This change is just a small trick and probably has little impact. I will make a PR later. Thank you.

Hello, does this mean that Prioritized Experience Replay in DQN isn't working in Stable Baselines?

@Jogima-cyber

Dare I say most of it is working, _except_ that the last added sample is never included in the random sampling. Given the number of samples in the buffer, this seems like a minuscule error (which should still be fixed!), but I cannot say for sure that its effect on learning is small.

@UPUPGOO

Any update on a PR for this? I am asking to check whether somebody is already working on it and wants to turn it into a PR. If not, I can do it.

Sorry for the late reply. I opened a PR, but it did not seem to pass the checks. Feel free to update it. Thank you.
