- A naive increase in context length leads to an increase in time and space complexity, since both scale quadratically with sequence length.
- Extrapolation pertains to the ability of the model to continue to perform well as the sequence length increases beyond the number of tokens seen during training.
- Handling longer contexts necessitates better extrapolation.
- A vanilla transformer has a fixed and limited context window, which leads to context fragmentation during training. Thus:
  - We cannot capture long-term dependencies between tokens that are separated by more than the context window, so information from the earliest tokens is lost.
  - During inference, the issue becomes one of performance: to maximize the available context we move forward one token at a time, re-processing tokens that have already been processed (sketched below).
  - The early token curse: perplexity degrades for early tokens in a subsequence since they cannot access many previous context tokens.
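A minimal sketch of the inference pattern described in the last two bullets, assuming a hypothetical `model(ids)` that returns next-token logits of shape `(batch, len, vocab)`; the function and names are illustrative, not taken from any specific library.

```python
import torch

def sliding_window_decode(model, ids, window, steps):
    """Greedy decoding with a fixed context window.

    Every step re-feeds the last `window` tokens through the model, so
    already-processed tokens are recomputed again and again, and anything
    older than `window` tokens is simply forgotten.
    """
    for _ in range(steps):
        ctx = ids[:, -window:]                         # truncate to the fixed window
        logits = model(ctx)                            # recomputed from scratch each step
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids
```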
Extending Context Memory
- 1 proposes the Compressive Transformer which compresses past memories for long-range sequence learning.
- This is an improvement on the Transformer-XL model: instead of discarding previous hidden states, we compress them.
- Let $n_m$ and $n_{cm}$ be the number of memory and compressive memory slots in the model per layer. The overall input sequence $S = x_1, x_2, \ldots, x_{|S|}$ is split into fixed-size windows of size $n_s$, and the model observes one window $x_t, \ldots, x_{t+n_s}$ at a time. As the model moves to the next window, its $n_s$ hidden activations are pushed into a FIFO memory of size $n_m$, with the oldest activations being evicted. Evicted activations are passed to a compression function $f_c : \mathbb{R}^{n_s \times d} \to \mathbb{R}^{\lfloor n_s / c \rfloor \times d}$ with compression rate $c$, which maps the $n_s$ oldest memories to $\lfloor n_s / c \rfloor$ compressed memories stored in a secondary FIFO compressed memory of size $n_{cm}$ (a rough sketch follows the list of compression choices below).
- Some choices for the compression function $f_c$:
  - Mean pooling with kernel and stride set to the compression rate $c$.
  - 1D convolution with kernel and stride set to $c$.
  - Dilated convolutions.
  - Preserving the most attended memories.
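A rough sketch of the memory update described above, using mean pooling with kernel and stride equal to the compression rate as the compression function; the tensor shapes and function names are my own assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F

def update_memories(mem, cmem, new_acts, n_m, n_cm, c):
    """Push new activations into FIFO memory; compress whatever gets evicted.

    mem:      (n_m, d)  FIFO of past hidden activations
    cmem:     (n_cm, d) FIFO of compressed memories
    new_acts: (n_s, d)  hidden activations of the current window
    """
    mem = torch.cat([mem, new_acts], dim=0)
    evicted, mem = mem[:-n_m], mem[-n_m:]            # oldest activations fall out

    if evicted.numel() > 0:
        # mean pooling over time with kernel = stride = compression rate c
        compressed = F.avg_pool1d(evicted.t().unsqueeze(0), kernel_size=c, stride=c)
        compressed = compressed.squeeze(0).t()       # (n_evicted // c, d)
        cmem = torch.cat([cmem, compressed], dim=0)[-n_cm:]

    return mem, cmem
```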
- The compression network is trained using an auto-encoding loss, with the goal of reconstructing the original memories from the compressed memories. The loss uses a learned decoder $g$ and minimizes $\mathcal{L}^{ae} = \lVert \textbf{old\_mem} - g(\textbf{new\_cm}) \rVert_2$, where $\textbf{old\_mem}$ and $\textbf{new\_cm}$ are the original and compressed memories respectively.
- We also add an attention-reconstruction loss that reconstructs the content-based attention over memory with content-based attention over compressed memory; this is what we use. Define the attention operator as $\text{attn}(h, m) = \sigma\big((h W_Q)(m W_K)^\top\big)(m W_V)$. Let $\textbf{old\_mem}^{(i)}$ and $\textbf{new\_cm}^{(i)}$ be the original and compressed memories at layer $i$. Then the loss is defined by summing across all layers:
  $$\mathcal{L}^{attn} = \sum_i \big\lVert \text{attn}\big(h^{(i)}, \textbf{old\_mem}^{(i)}\big) - \text{attn}\big(h^{(i)}, \textbf{new\_cm}^{(i)}\big) \big\rVert_2$$
  where $h^{(i)}$ is the hidden state at the $i$-th layer.
- Compression losses are not mixed with the losses of the main network.
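A sketch of the attention-reconstruction loss in the notation above, for a single layer; everything coming from the main network is detached so the compression loss does not mix with the main losses, as the bullet above states. All names are illustrative.

```python
import torch

def attn(h, m, w_q, w_k, w_v):
    """Content-based attention of hidden states h over memories m."""
    q, k, v = h @ w_q, m @ w_k, m @ w_v
    scores = torch.softmax(q @ k.t() / q.shape[-1] ** 0.5, dim=-1)
    return scores @ v

def attention_reconstruction_loss(h, old_mem, new_cmem, w_q, w_k, w_v):
    # detach hidden states and attention parameters: the compression loss
    # is kept separate from the main language-modelling objective
    h, old_mem = h.detach(), old_mem.detach()
    w_q, w_k, w_v = w_q.detach(), w_k.detach(), w_v.detach()
    target = attn(h, old_mem, w_q, w_k, w_v)     # attention over original memories
    pred = attn(h, new_cmem, w_q, w_k, w_v)      # attention over compressed memories
    return (target - pred).norm(p=2)             # summed over layers in practice
```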
- The temporal range of the Compressive Transformer with $l$ layers is $l \times (n_m + c \, n_{cm})$. The attention cost is $\mathcal{O}\big(n_s^2 + n_s (n_m + n_{cm})\big)$.
- 2 introduces Transformer-XL for learning dependencies beyond a fixed context length.
- Longer context is achieved using a segment-level recurrence mechanism coupled with a new relative positional encoding scheme.
- During training, the hidden state sequence from the previous segment(s) is fixed and cached. This is treated as an extended context that models long-term dependencies; the current segment's keys and values are conditioned on this extended context.
- More formally, let $s_\tau = [x_{\tau,1}, \ldots, x_{\tau,L}]$ be a segment and let $h_\tau^{n}$ denote the hidden state produced by the $n$-th layer for $s_\tau$. Then the $n$-th layer hidden state for $s_{\tau+1}$ is given by
  $$\tilde{h}_{\tau+1}^{\,n-1} = \big[\,\mathrm{SG}(h_\tau^{\,n-1}) \circ h_{\tau+1}^{\,n-1}\,\big]$$
  $$q_{\tau+1}^{\,n},\; k_{\tau+1}^{\,n},\; v_{\tau+1}^{\,n} = h_{\tau+1}^{\,n-1} W_q^\top,\; \tilde{h}_{\tau+1}^{\,n-1} W_k^\top,\; \tilde{h}_{\tau+1}^{\,n-1} W_v^\top$$
  $$h_{\tau+1}^{\,n} = \text{Transformer-Layer}\big(q_{\tau+1}^{\,n}, k_{\tau+1}^{\,n}, v_{\tau+1}^{\,n}\big)$$
  where $\mathrm{SG}(\cdot)$ denotes that we stop computing the gradient for the argument and $[\,\cdot \circ \cdot\,]$ denotes concatenation along the sequence dimension.
- We can cache more than the previous hidden state to produce a sequence of hidden states, referred to as the memory $m_\tau^{n}$.
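A minimal sketch of the segment-level recurrence in the equations above: cached states are detached (the stop-gradient) and concatenated before forming keys and values. Here `layer(q, k, v)` stands in for an ordinary attention block and is an assumption of this sketch.

```python
import torch

def xl_layer(layer, h_curr, mem, n_m):
    """One Transformer-XL style layer step.

    h_curr: (n_s, d) hidden states of the current segment at layer n-1
    mem:    (n_m, d) cached hidden states of previous segment(s), same layer
    """
    h_ext = torch.cat([mem.detach(), h_curr], dim=0)   # SG(mem) concatenated with h_curr
    q = h_curr                                         # queries: current segment only
    k, v = h_ext, h_ext                                # keys/values: extended context
    out = layer(q, k, v)                               # (n_s, d)
    new_mem = torch.cat([mem, h_curr], dim=0)[-n_m:]   # roll the memory forward
    return out, new_mem
```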
- During evaluation, representations from previous segments can be reused.
- Introducing recurrence necessitates the use of relative positional encoding to keep positional encodings coherent. See more in the linked page.
- The temporal range of Transformer-XL with $l$ layers and memory size $n_m$ is $l \times n_m$. The attention cost is $\mathcal{O}\big(n_s^2 + n_s n_m\big)$.
Non-Differentiable External Memory
- The primary approaches here revolve around introducing non-differentiable external memory via a key-value (KV) database.
- Memory being non-differentiable (not trained) is essential, as otherwise all keys and values would have to be recomputed whenever the model parameters change. 3
- 3 extends language models with the ability to memorize the internal representations of past inputs. Effectively, the model can acquire new knowledge immediately at inference time.
- The approach relies on kNN lookup. The focus is on unifying attention and retrieval within a single decoder-only transformer.
- The proposed model also includes a Transformer-XL style cache.
- We make use of a kNN-augmented attention layer. For the local context, it performs standard self-attention. For the global context, it does an approximate kNN search into external memory. This yields $V_m$ and $V_c$ for the attention results over external memory and the local context respectively. Both attention results are combined using a learned gate:
  $$V_a = g \odot V_m + (1 - g) \odot V_c, \qquad g = \sigma(b_g)$$
  where $b_g$ is a scalar parameter unique to each head.
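A sketch of the gate that combines the local and external-memory attention results, following the notation above; shapes are assumptions for illustration.

```python
import torch

def combine_heads(v_m, v_c, b_g):
    """Combine external-memory and local attention results per head.

    v_m, v_c: (num_heads, seq_len, head_dim) attention outputs
    b_g:      (num_heads,) learned per-head gate parameters
    """
    g = torch.sigmoid(b_g).view(-1, 1, 1)   # one scalar gate per head
    return g * v_m + (1.0 - g) * v_c
```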
- For each head, the external memory keeps a cache of the prior KV pairs.
- Documents that are processed over many time steps induce a distribution shift: old KV pairs become stale relative to newer ones. To reduce this effect, keys and queries are normalized so that old and new keys do not differ in magnitude.
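One simple way to keep old and new keys on a comparable scale, as described above, is to L2-normalize keys and queries before they are stored or used for retrieval; this is a sketch of the idea, not the paper's exact recipe.

```python
import torch.nn.functional as F

def normalize_kq(keys, queries, eps=1e-6):
    # unit-norm keys and queries so stale memory entries do not differ
    # in magnitude from freshly computed ones
    return F.normalize(keys, dim=-1, eps=eps), F.normalize(queries, dim=-1, eps=eps)
```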
- Unlike prior approaches, we use approximate kNN and learn the retrieval process. The same queries are used for both local and external memory.
- External memory is observed to provide an improvement (lower perplexity) at scale. The improvement in perplexity seems to be driven mainly by a small percentage of tokens that obtain a large improvement in cross-entropy loss when using the larger memory.
- A non-memory transformer can be fine-tuned to use memory.
- 4 introduces SPALM (Semi-Parametric Language Model), which mixes a transformer with non-parametric memory. This allows the model to obtain information from both its parameters and external memory depending on the context.
- The idea is to use short-term memory (temporary storage) for comprehending sentences, and long-term memory to store experiences, events, and knowledge. Modularizing the two facilitates easier training and a better model.
- The model consists of three main components. Essentially, it combines features introduced by Transformer-XL and kNN-LM.
- A parametric base model (a transformer) that processes the local context.
- A short-term memory module (i.e., like Transformer-XL) that stores hidden states from an extended context.
- Like in Transformer-XL, the extended context is the $m$ tokens prior to the current context, i.e., $x_{t-m}, \ldots, x_{t-1}$.
- Let $h_t^{l}$ be the hidden state for $x_t$ at layer $l$. The hidden states associated with the current context $x_t, \ldots, x_{t+n}$ are $\{h_t^{l}, \ldots, h_{t+n}^{l}\}$ and those of the extended context are $\{\mathrm{SG}(h_{t-m}^{l}), \ldots, \mathrm{SG}(h_{t-1}^{l})\}$, where $\mathrm{SG}$ is the stop-gradient function. The two are used as input to the attention mechanism (with relative positional encoding) to obtain the queries, keys, and values that produce $h_t^{l+1}$.
- A KV database (i.e., like kNN-LM) that stores compressed long-term context.
- Keys are compressed vector representations of the context.
- Values are the output token for that context, $x_{t+1}$.
- Unlike in kNN-LMs, here we incorporate the retrieval process within the architecture rather than mixing between the outputs of two pretrained models.
- The model combines the current context, the short-term memory, and past output tokens retrieved from long-term memory for similar contexts. This is all done via a gating mechanism.
- Suppose the current context is $x_1, \ldots, x_t$ and we wish to predict $x_{t+1}$. First we get the representation $\mathbf{h}_t$ of the context. We use this representation to perform a $k$-NN search and retrieve the top-$k$ values from the database. For each retrieved value $x_i$, we obtain a vector representation $\mathbf{y}_i$ using the base model's embedding matrix. The gating mechanism then operates as follows:
  $$\mathbf{g}_t = \sigma(\mathbf{w}_g \odot \mathbf{h}_t)$$
  $$\mathbf{z}_t = (1 - \mathbf{g}_t) \odot \mathbf{m}_t + \mathbf{g}_t \odot \mathbf{h}_t$$
  where $\mathbf{w}_g$ is a parameter and $\sigma$ is the sigmoid.
- The long-term information $\mathbf{m}_t$ is obtained by applying attention over the retrieved representations $\mathbf{y}_i$ using $\mathbf{h}_t$ as the query.
- The gate $\mathbf{g}_t$ is dependent on context and is used to decide how much local or long-term information should be used; it lets the model adaptively combine short- and long-term memory.
- When training the above, key representations stay constant but the word embedding matrix gets updated.
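A sketch of the gating step as reconstructed above: embed the retrieved tokens, attend over them with $\mathbf{h}_t$ as the query to get $\mathbf{m}_t$, then gate between long-term and local information. The shapes and names are illustrative assumptions.

```python
import torch

def spalm_combine(h_t, retrieved_tokens, embedding, w_g):
    """Combine the local hidden state with long-term memory (SPALM-style sketch).

    h_t:              (d,) hidden state of the current context
    retrieved_tokens: (k,) token ids returned by the kNN search
    embedding:        (vocab, d) word embedding matrix of the base model
    w_g:              (d,) gate parameter
    """
    y = embedding[retrieved_tokens]            # (k, d) embeddings of retrieved values
    alpha = torch.softmax(y @ h_t, dim=0)      # attention weights with h_t as the query
    m_t = alpha @ y                            # (d,) long-term information
    g_t = torch.sigmoid(w_g * h_t)             # context-dependent gate
    z_t = (1.0 - g_t) * m_t + g_t * h_t        # blend long-term and local information
    return z_t
```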
- Limitation: retrieving from the KV database is time-consuming.
- 5 introduces $k$NN-LMs, which incorporate a kNN model with a pre-trained language model.
- The results of their work suggest that learning similarity between sequences of text is easier than predicting the next word, and that nearest neighbor search is an effective approach for language modeling in the long tail.
- The rationale behind this is that identical sequences of words will have essentially the same distribution over the next word.
- The approach is to then memorize rare and surprising linguistic patterns rather than implicitly encoding them in model parameters.
- We augment a pre-trained LM with a nearest neighbor retrieval mechanism. This is done by maintaining a datastore.
- Let $f(\cdot)$ be a function that maps a context $c$ to a fixed vector representation computed by the pre-trained LM. We maintain a datastore built from the examples in the training set $\mathcal{D}$:
  $$(\mathcal{K}, \mathcal{V}) = \{(f(c_i), w_i) \mid (c_i, w_i) \in \mathcal{D}\}$$
  where $c_i = (w_1, \ldots, w_{i-1})$ (i.e., it is the prefix).
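A sketch of datastore construction under the definitions above, assuming a function `f` that maps a prefix to the LM's fixed-size context representation; the helper name is hypothetical.

```python
import torch

def build_datastore(f, training_tokens):
    """Build the kNN-LM datastore of (context representation, next word) pairs."""
    keys, values = [], []
    for i in range(1, len(training_tokens)):
        prefix, target = training_tokens[:i], training_tokens[i]
        keys.append(f(prefix))        # f(c_i): fixed vector representation of the prefix
        values.append(target)         # w_i: the word that followed this prefix
    return torch.stack(keys), torch.tensor(values)
```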
- At inference, given input context $x$, we generate the output distribution $p_{\mathrm{LM}}(y \mid x)$ using the language model along with the context representation $f(x)$. We query the datastore with $f(x)$ to retrieve the $k$ nearest neighbors $\mathcal{N}$ under the distance function $d(\cdot, \cdot)$. The distribution over the nearest neighbors is then computed using a softmax over the negative distances:
  $$p_{\mathrm{kNN}}(y \mid x) \propto \sum_{(k_i, v_i) \in \mathcal{N}} \mathbb{1}[y = v_i]\, \exp\!\big(-d(k_i, f(x))\big)$$
- The final kNN-LM distribution is
  $$p(y \mid x) = \lambda\, p_{\mathrm{kNN}}(y \mid x) + (1 - \lambda)\, p_{\mathrm{LM}}(y \mid x)$$
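A sketch of the interpolation at inference time, following the formulas above: squared L2 distance for retrieval, a softmax over negative distances for $p_{\mathrm{kNN}}$, and a fixed mixing weight $\lambda$. Names and defaults are illustrative assumptions.

```python
import torch

def knn_lm_probs(p_lm, query, keys, values, vocab_size, k=8, lam=0.25):
    """Interpolate the base LM distribution with a kNN distribution.

    p_lm:   (vocab,) next-word distribution from the pre-trained LM
    query:  (d,)     context representation f(x)
    keys:   (N, d)   datastore keys
    values: (N,)     datastore values (next-word ids)
    """
    dists = ((keys - query) ** 2).sum(dim=-1)           # squared L2 distances
    nn_dists, nn_idx = dists.topk(k, largest=False)     # k nearest neighbors
    weights = torch.softmax(-nn_dists, dim=0)           # softmax over negative distances
    p_knn = torch.zeros(vocab_size)
    p_knn.index_add_(0, values[nn_idx], weights)        # aggregate weight per next-word id
    return lam * p_knn + (1.0 - lam) * p_lm
```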
- kNN-LM is especially helpful for cases with rare patterns, for example factual knowledge, names, and near-duplicate sentences. In these cases it is easier to memorize the patterns via the representations than to learn them through next-word prediction.
- Its effectiveness comes from:
  - The Transformer is good at learning a representation function for contexts with an implicit notion of similarity.
  - The Transformer has the capacity to memorize all training examples, but doing so makes its representations less generalizable. The kNN component mitigates this by memorizing the training data explicitly in the datastore.
Links
- Zhang et al., Ch. 11 - for everything about the basics of the transformer model.
- All about Attention - more about attention.
Footnotes
- Rae et al. (2019). Compressive Transformers for Long-Range Sequence Modelling. ↩
- Dai et al. (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context. ↩
- Wu, Rabe, Hutchins, and Szegedy (2022). Memorizing Transformers. ↩ ↩2
- Yogatama, d'Autume, and Kong (2021). Adaptive Semiparametric Language Models. ↩
- Khandelwal et al. (2019). Generalization through Memorization: Nearest Neighbor Language Models. ↩