|
# RETRIEVAL-AUGMENTED REINFORCEMENT LEARNING |
|
|
|
**Anonymous authors** |
|
Paper under double-blind review |
|
|
|
ABSTRACT |
|
|
|
Most deep reinforcement learning (RL) algorithms distill experience into parametric behavior policies or value functions via gradient updates. While effective, |
|
this approach has several disadvantages: (1) it is computationally expensive, (2) |
|
it can take many updates to integrate experiences into the parametric model, (3) |
|
experiences that are not fully integrated do not appropriately influence the agent’s |
|
behavior, and (4) behavior is limited by the capacity of the model. In this paper we |
|
explore an alternative paradigm in which we train a network to map a dataset of |
|
past experiences to optimal behavior. Specifically, we augment an RL agent with |
|
a retrieval process (parameterized as a neural network) that has direct access to a |
|
dataset of experiences. This dataset can come from the agent’s past experiences, |
|
expert demonstrations, or any other relevant source. The retrieval process is trained |
|
to retrieve information from the dataset that may be useful in the current context, to |
|
help the agent achieve its goal faster and more efficiently. We integrate our method |
|
into two different RL agents: an offline DQN agent and an online R2D2 agent. |
|
In offline multi-task problems, we show that the retrieval-augmented DQN agent |
|
avoids task interference and learns faster than the baseline DQN agent. On Atari, |
|
we show that retrieval-augmented R2D2 learns significantly faster than the baseline |
|
R2D2 agent and achieves higher scores. We run extensive ablations to measure the |
|
contributions of the components of our proposed method. |
|
|
|
1 INTRODUCTION |
|
|
|
A host is preparing a holiday meal for friends. They remember that the last time they went to the |
|
grocery store during the holiday season, all of the fresh produce was sold out. Thinking back to this |
|
past experience, they decide to go early! The hypothetical host is employing case-based reasoning |
|
(e.g., Kolodner, 1992; Leake, 1996). Here, an agent recalls a situation similar to the current one and |
|
uses information from the previous experience to solve the current task. This may involve adapting |
|
old solutions to meet new demands, or using previous experiences to make sense of new situations. |
|
|
|
In contrast, a dominant paradigm in modern reinforcement learning (RL) is to learn general purpose |
|
behavior rules from the agent’s past experience. These rules are typically represented in the
|
weights of a parametric policy or value function network model. Most deep RL algorithms integrate |
|
information across trajectories by iteratively updating network parameters using gradients that are |
|
computed along individual trajectories (collected online or stored in an experience replay dataset, |
|
Lin, 1992). For example, many off-policy algorithms reuse past experience by “replaying” trajectory |
|
snippets in order to compute weight updates for a value function represented by a deep network |
|
(Ernst et al., 2005; Riedmiller, 2005; Mnih et al., 2015b; Heess et al., 2015; Lillicrap et al., 2015). |
|
|
|
This paradigm has clear advantages but at least two interrelated limitations: First, after learning, an |
|
agent’s past experiences no longer play a direct role in the agent’s behavior, even if they are relevant
|
to the current situation. This occurs because detailed information in the agent’s past experience is |
|
lost due to practical constraints on network capacity. Second, since the information provided by |
|
individual trajectories first needs to be distilled into a general purpose parametric rule, an agent may |
|
not be able to exploit the specific guidance that a handful of individual past experiences could provide, |
|
nor rapidly incorporate novel experience that becomes available—it may take many replays through |
|
related traces in the past experiences for this to occur (Weisz et al., 2021). |
|
|
|
|
|
|
|
|
Figure 1: Retrieval-augmented agent (R2A) architecture: (A) R2A augments the agent with a retrieval process. The retrieval process and |
|
the agent maintain separate internal states, mt and st, respectively. The retrieval process retrieves information relevant to the agent’s current |
|
internal state st from the retrieval batch, which is a pre-processed sample from the retrieval dataset B. The retrieved information ut is used |
|
by the agent process to inform its output (e.g., a policy or value function). (B) A batch of raw trajectories is sampled from the retrieval dataset |
|
B and encoded (using the same encoder as the agent). Each encoded trajectory is then summarized via forward and backward summarization
|
functions (section 2.2) and sent to the retrieval process. (C) The retrieval process is parameterized as a recurrent model and the internal |
|
state mt is partitioned into slots. Each slot independently retrieves information from the retrieval batch, which is used to update the slot’s |
|
representation and is sent to the agent process in ut. Slots also interact with each other via self-attention. See section 2.3 for more details.
|
|
|
In this work, we develop an algorithm that overcomes these limitations by augmenting a standard |
|
reinforcement learning agent with a retrieval process (parameterized via a neural network). The |
|
purpose of the retrieval process is to help the agent achieve its objective by providing relevant |
|
contextual information. To this end, the retrieval process uses a learned attention mechanism to |
|
dynamically access a large pool of past trajectories stored in a dataset (e.g., a replay buffer), with |
|
the aim of integrating information across these. The proposed algorithm (R2A), shown in Figure 1, |
|
enables an agent to retrieve information from a dataset of trajectories. The high-level idea is to have |
|
two different processes. First, the retrieval process makes a “query” to search for relevant contextual
|
information in the dataset. Second, the agent process performs inference and learning based on the |
|
information provided by the retrieval process. These two processes have different internal states |
|
but interact to shape the representations and predictions of each other: the agent process provides |
|
the relevant context, and the retrieval process uses the context and its own internal state to generate |
|
a query and retrieve relevant information, which is in turn used by the agent process to shape the |
|
representation of its policy and value function (see Fig. 1A). Our proposed retrieval-augmented RL |
|
paradigm could take several forms. Here, we focus on a particular instantiation to assay and validate |
|
our hypothesis that learning a retrieval process can help an RL agent achieve its objectives. |
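
To make this interaction concrete, the following minimal Python sketch traces one timestep of the loop described above; the callable names (`encoder`, `retrieval_process`, `agent_process`) are illustrative placeholders rather than the paper's implementation.

```python
# Minimal sketch of one timestep of the retrieval-augmented loop described above.
# The callables are illustrative placeholders, not the paper's implementation.
def r2a_timestep(x_t, m_prev, encoder, retrieval_process, agent_process):
    s_t = encoder(x_t)                          # agent's abstract internal state
    m_t, u_t = retrieval_process(m_prev, s_t)   # query the dataset, update retrieval state
    q_values = agent_process(s_t, u_t)          # value/policy estimates shaped by retrieval
    return q_values, m_t
```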
|
|
|
**Summary of experimental results.** We want RL algorithms that are able to adapt to the available
|
data source and usefully ingest any dataset. Hence, we test the performance of the proposed method |
|
in three different scenarios. First, we evaluate it on Atari games in a single task setting. We build upon |
|
R2D2 (Kapturowski et al., 2018), a state-of-the-art off-policy RL algorithm. Second, we evaluate it |
|
on a multi-task offline RL environment, using DQN (Mnih et al., 2013) as the RL algorithm, where |
|
the data in the queried dataset belongs to the same task. Third, we evaluate it on a multi-task offline |
|
RL environment where the data in the dataset comes from multiple tasks. In all these cases, we show |
|
that R2A learns faster and achieves higher reward compared to the baseline. |
|
|
|
2 RETRIEVAL-AUGMENTED AGENTS |
|
|
|
We now present our method for augmenting an RL agent with a retrieval process, thereby reducing |
|
the agent’s dependence on its model capacity, and enabling fast and flexible use of past experiences. |
|
A retrieval-augmented agent (R2A) consists of two main components: (1) the retrieval process, |
|
which takes in the current state of the agent, combines this with its own internal state, and retrieves |
|
relevant information from an external dataset of experiences; and (2) a standard reward-maximizing |
|
RL agent, which uses the retrieved information to improve its value or policy estimates. See Figure 1 |
|
for an overview. The retrieval process is trained to retrieve information that the agent can use to |
|
improve its performance, without explicit knowledge of the agent’s policy. Importantly, the retrieval |
|
process has its own internal state, which enables it to integrate and combine information across |
|
|
|
|
|
|
|
|
retrievals. In the following, we focus on value-based methods, such as DQN (Mnih et al., 2015a) and |
|
R2D2 (Kapturowski et al., 2018), but our approach is equally applicable to policy-based methods. |
|
|
|
2.1 RETRIEVAL-AUGMENTED AGENT |
|
|
|
Formally, the agent receives an input $x_t$ at each timestep $t$. Each input is processed by a neural encoder (e.g., a resnet if the input is an image) to obtain an abstract internal state for the agent, $s_t = f^{\text{enc}}_\theta(x_t)$. For clarity, we focus here on the case of a single vector input; however, each input could also include the history of past observations, actions, and rewards, as is the case when $f^{\text{enc}}_\theta$ is a recurrent network. These embeddings are used by the agent and retrieval processes. The retrieval process operates on a dataset $\mathcal{B} = \{(x_t, a_t, r_t), \ldots, (x_{t+l}, a_{t+l}, r_{t+l})\}$ of $l$-step trajectories, for $l \geq 1$. This dataset could come from other agents or experts, as in offline RL or imitation learning, or consist of the growing set of the agent’s own experiences. Then, a retrieval-augmented agent (R2A) consists of the retrieval process and the agent process, parameterized by $\theta = \{\theta^{\text{enc}}, \theta^{\text{retr}}, \theta^{\text{agent}}\}$:
|
|
|
**Retrieval process** $f^{\text{retr}}_{\theta,\mathcal{B}} : \; m_{t-1}, s_t \mapsto m_t, u_t$

**Agent process** $f^{\text{agent}}_{\theta} : \; s_t, u_t \mapsto Q_\theta(s_t, u_t, a)$
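
As a reading aid, here is a minimal Python sketch of these two signatures; the `Protocol` classes and argument names are our own shorthand for the mappings above, not code from the paper.

```python
from typing import Protocol, Tuple
import torch

class RetrievalProcess(Protocol):
    """f^retr_{theta,B}: (m_{t-1}, s_t) -> (m_t, u_t), with read access to the dataset B."""
    def __call__(self, m_prev: torch.Tensor, s_t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...

class AgentProcess(Protocol):
    """f^agent_theta: (s_t, u_t) -> Q_theta(s_t, u_t, a) for all actions a."""
    def __call__(self, s_t: torch.Tensor, u_t: torch.Tensor) -> torch.Tensor: ...
```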
|
|
|
**_Retrieval Process._** The retrieval process is parameterized as a neural network and has an internal state $m_t$. It takes in the current abstract state of the agent process $s_t$ and its own previous internal state $m_{t-1}$, and uses these to retrieve relevant information from the dataset $\mathcal{B}$, which it then summarizes in a vector $u_t$, while also updating its internal state to $m_t$.
|
|
|
**_Agent Process._** The current state of the agent $s_t$ and the information from the retrieval process $u_t$ are then passed to the action-value function, which is itself used to select external actions.
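
A minimal sketch of such an action-value head, assuming the retrieved summary $u_t$ is simply concatenated with $s_t$ before a small MLP (the concatenation and layer sizes are illustrative assumptions, not the paper's exact architecture):

```python
import torch
import torch.nn as nn

class RetrievalAugmentedQHead(nn.Module):
    """Maps (s_t, u_t) to Q-values over actions; the hidden size is an arbitrary choice."""
    def __init__(self, state_dim: int, retrieved_dim: int, num_actions: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + retrieved_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_actions),
        )

    def forward(self, s_t: torch.Tensor, u_t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([s_t, u_t], dim=-1))  # Q(s_t, u_t, ·)
```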
|
|
|
The above defines a parameterization for a retrieval-augmented agent. For retrieval to be effective, |
|
the retrieval process needs to: (1) be able to efficiently query a large dataset of trajectories, (2) learn |
|
and employ a similarity function to find relevant trajectories, and (3) encode and summarize the |
|
trajectories in a manner that allows efficient discovery of relevant past and future information. |
|
|
|
Below, we explain how we achieve these desiderata. At a high-level, to reduce computational |
|
complexity given an experience dataset of hundreds of thousands of trajectories, R2A operates on
|
samples from the dataset. R2A then encodes and summarizes the trajectories in these samples |
|
using auxiliary losses and bi-directional sequence models to enable efficient retrieval of temporal |
|
information. Finally, R2A uses attention to select semantically relevant trajectories. |
|
|
|
2.2 RETRIEVAL BATCH SAMPLING AND PRE-PROCESSING
|
|
|
**_Sampling a retrieval batch from the retrieval dataset. To reduce the computational complexity,_** |
|
R2A uniformly samples a large batch of past experiences from the retrieval dataset, and only uses the |
|
sampled batch for retrieving information. We refer to this sampled batch as the “retrieval batch” and denote the number of trajectories in it by $n_{\text{retrieval}}$.
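
A sketch of this sampling step, assuming the retrieval dataset is held as an in-memory list of trajectories (the helper name `sample_retrieval_batch` is ours):

```python
import random

def sample_retrieval_batch(retrieval_dataset: list, n_retrieval: int) -> list:
    """Uniformly sample n_retrieval trajectories to form the retrieval batch."""
    return random.sample(retrieval_dataset, k=min(n_retrieval, len(retrieval_dataset)))
```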
|
|
|
**_Encoding and forward-backward summarization of the retrieval dataset and corresponding auxiliary losses._** Since the agent’s internal state extracts task-relevant information from observations, we choose to re-encode the raw experiences in the “retrieval batch” using the agent’s encoder module (i.e., $f^{\text{enc}}_\theta$). However, this representation is a function only of past observations (i.e., it is a causal representation) and may not be fully compatible with the needs of the retrieval operation. For that reason, we further encode the retrieval batch by learning a summarization function, applied to the output of the encoder module, which captures information about both the past and the future within a particular trajectory using a bi-directional model (e.g., parameterized as a bi-directional RNN or a Transformer).
|
|
|
**Forward Summarizer** $f^{\text{fwd}}_\theta : (s_1, \ldots, s_t) \mapsto h_t$

**Backward Summarizer** $f^{\text{bwd}}_\theta : (s_l, \ldots, s_t) \mapsto b_t$
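
One possible instantiation of the two summarizers, assuming a pair of GRUs run over the encoded trajectory in opposite directions (the paper also allows a Transformer; the module below is a sketch under that GRU assumption):

```python
import torch
import torch.nn as nn

class ForwardBackwardSummarizer(nn.Module):
    """Produces h_t (summary of s_1..s_t) and b_t (summary of s_l..s_t) for every t."""
    def __init__(self, state_dim: int, hidden_dim: int):
        super().__init__()
        self.fwd = nn.GRU(state_dim, hidden_dim, batch_first=True)
        self.bwd = nn.GRU(state_dim, hidden_dim, batch_first=True)

    def forward(self, states: torch.Tensor):
        # states: [n_traj, l, state_dim], encoded with the agent's encoder f_enc
        h, _ = self.fwd(states)                            # h[:, t] summarizes s_1..s_t
        b_rev, _ = self.bwd(torch.flip(states, dims=[1]))  # run over the reversed trajectory
        b = torch.flip(b_rev, dims=[1])                    # b[:, t] summarizes s_l..s_t
        return h, b
```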
|
|
|
For each trajectory in the retrieval batch, we represent each time-step within a trajectory by a set of two vectors, $h_{i,t}$ and $b_{i,t}$ (Figure 6 in the appendix), where $h_{i,t}$ summarizes the past (i.e., from the beginning of the trajectory up to time-step $t$) and $b_{i,t}$ summarizes the future (i.e., from the end of the trajectory back to time-step $t$).
|
|
|
|
|
|
|
|
**Algorithm 1 One timestep of a retrieval-augmented agent (R2A).** |
|
|
|
**_Input:_** Current input $x_t$; previous retrieval process state $m_{t-1} = \{m_{t-1,k} \mid k \in \{1, \ldots, n_f\}\}$; dataset of $l$-step trajectories $\mathcal{B} = \{(x^i_t, h^i_t, b^i_t, a^i_t, r^i_t), \ldots, (x^i_{t+l}, h^i_{t+l}, b^i_{t+l}, a^i_{t+l}, r^i_{t+l})\}$ for $l \geq 1$ and $1 \leq i \leq n_{\text{traj}}$, where $h$ and $b$ are the outputs of the forward and backward summarizers.
|
|
|
**Encode the current input at time-step t.** |
|
$s_t = f^{\text{enc}}_\theta(x_t)$
|
|
|
**Step 1: Compute the query.** For all $1 \leq k \leq n_f$, compute

$\hat{m}^k_{t-1} = \mathrm{GRU}_\theta\left(s_t, m^k_{t-1}\right)$

$q^k_t = f_{\text{query}}\left(\hat{m}^k_{t-1}\right)$
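
A sketch of Step 1, assuming a single GRU cell shared across the $n_f$ slots and a linear projection for $f_{\text{query}}$ (both parameterization choices are our assumptions):

```python
import torch
import torch.nn as nn

class QueryStep(nn.Module):
    """Step 1: update each slot of the retrieval state with s_t and emit one query per slot."""
    def __init__(self, state_dim: int, slot_dim: int, query_dim: int):
        super().__init__()
        self.gru = nn.GRUCell(state_dim, slot_dim)      # m_hat^k_{t-1} = GRU(s_t, m^k_{t-1})
        self.f_query = nn.Linear(slot_dim, query_dim)   # q^k_t = f_query(m_hat^k_{t-1})

    def forward(self, s_t: torch.Tensor, m_prev: torch.Tensor):
        # s_t: [state_dim]; m_prev: [n_f, slot_dim] (one row per slot)
        n_f = m_prev.shape[0]
        m_hat = self.gru(s_t.expand(n_f, -1), m_prev)   # the same s_t is fed to every slot
        return m_hat, self.f_query(m_hat)               # updated slots and queries [n_f, query_dim]
```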
|
|
|
**Step 2: Identify the most relevant trajectories.** For all $1 \leq k \leq n_f$, $1 \leq j \leq l$, and $1 \leq i \leq n_{\text{traj}}$,

$\kappa_{i,j} = \left(h^i_j W^e_{\text{ret}}\right)^{\top}$

$\ell^k_{i,j} = \dfrac{q^k_t \, \kappa_{i,j}}{\sqrt{d_e}}$

$\alpha^k_{i,j} = \mathrm{softmax}\left(\ell^k_{i,j}\right).$

Given scores $\alpha$, the top-$k_{\text{traj}}$ trajectories (resp. top-$k_{\text{states}}$ states) are selected and denoted by $\mathcal{T}^k_t$ (resp. $\mathcal{S}^k_t$).
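
A sketch of Step 2's scoring and selection, assuming the forward summaries of the retrieval batch are stacked into a single tensor and that trajectory-level relevance is obtained by summing a trajectory's per-state attention weights (the latter aggregation is our reading, not spelled out above):

```python
import torch
import torch.nn.functional as F

def score_and_select(q: torch.Tensor, h: torch.Tensor, w_ret: torch.Tensor, k_traj: int):
    """Step 2 sketch: attention scores between slot queries and forward summaries, then top-k."""
    # q: [n_f, d_e] slot queries; h: [n_traj, l, d_h] forward summaries; w_ret: [d_h, d_e]
    keys = h @ w_ret                                                        # kappa_{i,j}
    d_e = q.shape[-1]
    logits = torch.einsum('kd,ijd->kij', q, keys) / d_e ** 0.5              # l^k_{i,j}
    alpha = F.softmax(logits.flatten(start_dim=1), dim=-1).view_as(logits)  # alpha^k_{i,j}
    traj_scores = alpha.sum(dim=-1)                   # relevance of trajectory i to slot k
    top_traj = traj_scores.topk(k_traj, dim=-1).indices  # indices forming T^k_t
    return alpha, top_traj
```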
|
|