Source code for rl.memory

from __future__ import absolute_import
from collections import namedtuple
from rl.utils.memory import RingBuffer, sample_batch_indexes
import numpy as np
import pickle

# One transition: taking `action` while in `state0` yields `reward` and leads
# to `state1`; `terminal1` records whether `state1` is terminal.
Experience = namedtuple(
    'Experience',
    ['state0', 'action', 'reward', 'state1', 'terminal1'],
)

# A batch of transitions.
# It stores data element-wise (one array per field, one row per transition)
# instead of experience-wise.
Batch = namedtuple("Batch", ("state0", "action", "reward", "state1",
                             "terminal1"))

class Memory(object):
    """Abstract memory class.

    Stores the environment and declares the interface that concrete
    memories implement; every data-handling method here raises
    ``NotImplementedError``.
    """

    def __init__(self, env):
        """
        :param env: the environment the stored experiences come from
            (subclasses may read its spaces, e.g. to size batch arrays)
        """
        self.env = env

    def sample(self, batch_size):
        """Get a sample from the memory.

        :param int batch_size: size of the batch
        :return: A :class:`Batch` object
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def append(self, experience):
        """Add the experience to the memory.

        :param experience: the transition to store
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()
class SimpleMemory(Memory):
    """A simple memory directly storing experiences in a circular buffer.

    Data is stored directly as an array of :class:`Experience`.
    """

    def __init__(self, env, limit):
        """
        :param env: environment whose ``observation_space.dim`` and
            ``action_space.dim`` size the sampled batch arrays
        :param int limit: capacity passed to the underlying ``RingBuffer``
        """
        super(SimpleMemory, self).__init__(env)
        self.buffer = RingBuffer(limit)

    def get_idxs(self, idxs, batch_size=None):
        """Gather a non-contiguous series of buffer indexes into a batch.

        :param idxs: buffer indexes to read, in output-row order
        :param int batch_size: number of rows to allocate; defaults to
            ``len(idxs)``
        :return: A :class:`Batch` whose arrays each have ``batch_size`` rows
        """
        if batch_size is None:
            batch_size = len(idxs)
        obs_dim = self.env.observation_space.dim
        action_dim = self.env.action_space.dim
        # Allocate memory. np.empty does not initialise, so any row not
        # written below (len(idxs) < batch_size) holds arbitrary values.
        state0_batch = np.empty((batch_size, obs_dim))
        action_batch = np.empty((batch_size, action_dim))
        reward_batch = np.empty((batch_size, 1))
        terminal1_batch = np.empty((batch_size, 1), dtype=bool)
        state1_batch = np.empty((batch_size, obs_dim))
        for batch_index, memory_index in enumerate(idxs):
            experience = self.buffer[memory_index]
            state0_batch[batch_index, :] = experience.state0
            action_batch[batch_index, :] = experience.action
            reward_batch[batch_index, :] = experience.reward
            terminal1_batch[batch_index, :] = experience.terminal1
            state1_batch[batch_index, :] = experience.state1
        return Batch(
            state0=state0_batch,
            action=action_batch,
            reward=reward_batch,
            terminal1=terminal1_batch,
            state1=state1_batch)

    def sample(self, batch_size, batch_idxs=None):
        """Get a batch of experiences from the memory.

        :param int batch_size: size of the batch
        :param batch_idxs: explicit buffer indexes to use; drawn at random
            when ``None``
        :return: A :class:`Batch` object
        :raises IndexError: if the memory holds fewer than ``batch_size``
            experiences
        """
        available_samples = len(self)
        if batch_size > available_samples:
            raise IndexError(
                "Not enough elements in the memory (currently {}) "
                "to sample a batch of size {}".format(len(self), batch_size))
        if batch_idxs is None:
            # Draw random indexes such that we have at least a single entry
            # before each index.
            # NOTE(review): get_idxs never reads the preceding entry, so this
            # +1 shift looks inherited from a windowed memory. Assuming
            # sample_batch_indexes returns values >= its `low` argument,
            # index 0 is never drawn. Confirm the helper's range semantics
            # before changing.
            batch_idxs = sample_batch_indexes(
                0, available_samples - 1, size=batch_size)
            batch_idxs = np.array(batch_idxs) + 1
        return self.get_idxs(batch_idxs, batch_size=batch_size)

    def append(self, experience):
        """Add the experience to the memory (to the ring buffer)."""
        self.buffer.append(experience)

    @classmethod
    def from_file(cls, env, limit, file_path):
        """Create a memory from a pickle file.

        :param env: environment passed to the constructor
        :param int limit: capacity of the new memory
        :param str file_path: path of a pickled sequence of experience
            tuples (e.g. one written by :meth:`save`)
        :return: a new memory containing those experiences

        .. warning:: ``pickle.load`` can execute arbitrary code while
           loading; only load files from a trusted source.
        """
        with open(file_path, "rb") as fd:
            memory_database = pickle.load(fd)
        memory = cls(limit=limit, env=env)
        for experience in memory_database:
            memory.append(Experience(*experience))
        return memory

    def save(self, file):
        """Dump the memory into a pickle file.

        :param str file: destination path (truncated if it exists)
        """
        print("Saving memory")
        with open(file, "wb") as fd:
            pickle.dump(self.buffer.dump(), fd)

    def dump(self):
        """Get the memory content as a single array (``buffer.dump()``)."""
        return self.buffer.dump()

    def __len__(self):
        """Number of experiences currently stored."""
        return len(self.buffer)