US20250307633A1
CONVERTING AND UPTRAINING PERFORMANT TRANSFORMERS WITH FULL ATTENTION USING NORMALIZED RECURRENCE
Applicants
Toyota Research Institute, Inc.
Inventors
Igor Vasiljevic, Jean Mercat, Sedrick Keh, Achal Dave, Kushal Arora, Thomas Kollar
Abstract
A method may include receiving parameters associated with a pre-trained transformer trained on first training data, modifying an architecture of the pre-trained transformer to generate a modified transformer, the modified transformer replacing a dot-product softmax attention layer with a linear kernel dot product attention layer utilizing Group Normalization, receiving second training data, and training the modified transformer based on the second training data.
Description
CROSS-REFERENCE TO RELATED APPLICATION
[0001]The present specification is based on, and claims priority from U.S. Provisional Application No. 63/571,605, filed Mar. 29, 2024, the disclosure of which is hereby incorporated by reference in its entirety.
TECHNICAL FIELD
[0002]The present specification relates to training large language models, and more particularly, to converting and uptraining performant transformers with full attention using normalized recurrence.
BACKGROUND
[0003]Large language models (LLMs) have become a popular form of generative artificial intelligence in recent years. LLMs can receive a text prompt and generate text in response to the prompt. LLMs are typically built on transformers, which offer high parallel training efficiency and favorable scaling performance. A transformer is trained using training data comprising a large number of tokens (e.g., sequences of text). However, the training efficiency of transformers comes at the expense of an inference cost that scales linearly with the number of tokens in the context. The resulting memory-intensive nature of transformer inference has led to renewed interest in recurrent neural networks (RNNs).
[0004]RNNs are another form of neural network that can be used for sequence modeling tasks. However, RNNs do not have the training efficiency and scaling performance of transformers, and transformers have therefore largely displaced RNNs in sequence modeling tasks. On the other hand, while the inference cost of transformers scales linearly with the number of tokens, RNNs have a fixed cost per token at inference. Accordingly, there is a need for LLMs that combine the training benefits of transformers with the inference benefits of RNNs.
SUMMARY
[0005]In one embodiment, a method may include receiving parameters associated with a pre-trained transformer trained on first training data, modifying an architecture of the pre-trained transformer to generate a modified transformer, the modified transformer replacing a dot-product softmax attention layer with a linear kernel dot product attention layer utilizing Group Normalization, receiving second training data, and training the modified transformer based on the second training data.
[0006]In another embodiment, a computing device may comprise one or more processors configured to receive parameters associated with a pre-trained transformer trained on first training data, modify an architecture of the pre-trained transformer to generate a modified transformer, the modified transformer replacing a dot-product softmax attention layer with a linear kernel dot product attention layer utilizing Group Normalization, receive second training data, and train the modified transformer based on the second training data.
BRIEF DESCRIPTION OF THE DRAWINGS
[0007]The embodiments set forth in the drawings are illustrative and exemplary in nature and are not intended to limit the disclosure. The following detailed description of the illustrative embodiments can be understood when read in conjunction with the following drawings, where like structure is indicated with like reference numerals and in which:
[0008]
[0009]
[0010]
[0011]
[0012]
[0013]
[0014]
DETAILED DESCRIPTION
[0015]The embodiments disclosed herein include a method of converting a transformer into an RNN. In embodiments, an LLM (e.g., Llama 2) trained on high-quality, proprietary datasets, which are not available for pre-training linear models, may be used as a starting point. Such models are often trained on trillions of tokens. The pre-trained model may then be fine-tuned, or uptrained, on a small fraction of the pre-training tokens using publicly available data, yielding linear models that are competitive with the best linear transformers at a fraction of the compute cost, as disclosed herein.
[0016]Recurrent neural networks are a type of artificial neural network that may be used for sequential data processing. Unlike feedforward neural networks, which process data in a single pass, RNNs process data across multiple time steps, making them well suited to modeling and processing time-series data such as text and speech. In operation, an RNN generates a sequence of hidden states, each computed as a function of the previous hidden state and the input at the current position. This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths because memory constraints limit batching across examples.
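For illustration only, the sequential dependence described above can be sketched in Python with an Elman-style recurrence h_t = tanh(W_h h_{t-1} + W_x x_t). The recurrence form, the function names, and the weights are assumptions chosen for brevity rather than part of the disclosure; the sketch shows that step t cannot begin until step t-1 has produced its hidden state, and that the state has a fixed size regardless of sequence length.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rnn_step(h_prev: Vector, x_t: Vector, w_h: Matrix, w_x: Matrix) -> Vector:
    """One Elman-style step: h_t = tanh(W_h h_{t-1} + W_x x_t)."""
    return [math.tanh(a + b) for a, b in zip(matvec(w_h, h_prev), matvec(w_x, x_t))]


def rnn_forward(xs: List[Vector], w_h: Matrix, w_x: Matrix) -> List[Vector]:
    """Run the recurrence over a sequence; each step needs the previous state."""
    h = [0.0] * len(w_h)  # fixed-size hidden state, independent of sequence length
    states = []
    for x_t in xs:  # inherently sequential loop over time steps
        h = rnn_step(h, x_t, w_h, w_x)
        states.append(h)
    return states
```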
[0017]In order to overcome the limitations of RNNs, the transformer model was developed. Rather than relying on recurrence, transformers rely on an attention mechanism to draw global dependencies between input and output. The transformer model allows for significantly more parallelization and has been used to train large language models.
[0018]In a transformer architecture used for LLMs, text is converted to numerical representations called tokens. Each token is then converted into a vector using word embedding. At each layer, each token is contextualized within the scope of a context window with other tokens via a parallel multi-head attention mechanism, which allows the signal for key tokens to be amplified and less important tokens to be diminished.
[0019]
[0020]Each layer 106 of the encoder 102 contains two sub-layers: a multi-head attention layer 110 and a fully connected feed forward layer 112. A residual connection may be employed around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself.
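A minimal Python sketch of the post-norm residual wrapper LayerNorm(x + Sublayer(x)) follows. The helper names are hypothetical, and the learnable gain and bias of layer normalization are omitted as a simplifying assumption.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm without the learnable gain and bias (omitted for brevity)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_sublayer(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    """Post-norm residual wrapper: LayerNorm(x + Sublayer(x))."""
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])


# Usage: a toy "sub-layer" that doubles its input, so x + Sublayer(x) = 3x.
out = residual_sublayer([1.0, 2.0, 3.0], lambda v: [2.0 * t for t in v])
```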
[0021]Each layer 108 of the decoder 104 comprises three sub-layers: a masked multi-head attention layer 114, a multi-head attention layer 116, and a fully connected feed forward network 118. The multi-head attention layer 116 performs multi-head attention over the output of the encoder stack. Similar to the encoder 102, the decoder 104 employs residual connections around each of the sub-layers, followed by layer normalization. The self-attention sub-layer in the decoder stack, i.e., the masked multi-head attention layer 114, is modified to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i. The decoder 104 also includes a linear transformation layer 120 and a softmax layer 122.
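The decoder masking can be illustrated with the following Python sketch, which builds a causal mask in which position i may attend only to positions j ≤ i and applies it inside a softmax. The function names are hypothetical; masked positions receive a score of negative infinity and therefore zero weight.

```python
import math
from typing import List


def causal_mask(n: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j, i.e. j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_softmax(scores: List[float], allowed: List[bool]) -> List[float]:
    """Softmax over one row of scores, giving zero weight to masked positions."""
    logits = [s if ok else -math.inf for s, ok in zip(scores, allowed)]
    m = max(logits)  # finite, because position i can always attend to itself
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]
```

For example, row 1 of a three-token mask lets position 1 split its attention between positions 0 and 1 while the large score at the future position 2 is ignored.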
[0022]
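[0023]In the vanilla transformer 100, the attention is “Scaled Dot-Product Attention”, which can be performed by the scaled dot-product attention layer 130:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V \qquad (1)$$
where Q, K, and V are the query, key, and value matrices, respectively, and d is the dimension of the queries and keys.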
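As an illustrative sketch, scaled dot-product attention, softmax(QK^T/√d)V, can be written in plain Python as follows. The function names are hypothetical; the nested loops make explicit the O(dN^2) cost of comparing every query with every key.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d)) V, computed one query row at a time."""
    d = len(q[0])
    out = []
    for q_i in q:  # N queries ...
        # ... each scored against all N keys: O(d N^2) work overall
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d) for k_j in k]
        w = softmax(scores)
        out.append([sum(w_j * v_j[c] for w_j, v_j in zip(w, v)) for c in range(len(v[0]))])
    return out
```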
[0024]Transformers have outperformed RNNs in natural language generation tasks such as those performed by LLMs. However, this comes with a significant computational cost and memory footprint during generation. Since each output is predicted incrementally, conditioned on the prefix, generation steps cannot be parallelized over time steps, and the total cost grows quadratically with sequence length. In particular, computing the softmax dot-product attention of equation (1) costs $O(dN^2)$ for a sequence of length N and feature dimension d. The memory consumption of each generation step also grows linearly as the sequence becomes longer. This bottleneck for long-sequence generation limits the use of large-scale pre-trained transformers. Accordingly, it may be desirable to use a more efficient transformer that uses a more memory-efficient normalization function than the softmax function.
[0025]In equation (1), the self-attention function computes, for every position, a weighted average of the feature representations of all positions, with weights proportional to a similarity score between the representations. In particular, in the softmax attention of equation (1), the similarity score is the exponential of the scaled dot product between a query and a key. Given that subscripting a matrix with i returns the i-th row as a vector, we can write a generalized attention equation for any similarity function as follows:
$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\,V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)} \qquad (2)$$
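[0026]Equation (2) is equivalent to equation (1) if the similarity function is substituted with $\mathrm{sim}(q, k) = \exp\left(q^{\top}k / \sqrt{d}\right)$. Furthermore, a kernel function ϕ(x) can be defined that maps queries and keys to their hidden representations. Given such a kernel, equation (2) can be rewritten as follows:
$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^{\top}\phi(K_j)\,V_j}{\sum_{j=1}^{N} \phi(Q_i)^{\top}\phi(K_j)} \qquad (3)$$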
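[0027]The associative property of matrix multiplication can be used to rewrite equation (3) as follows:
$$V'_i = \frac{\phi(Q_i)^{\top}\sum_{j=1}^{N}\phi(K_j)\,V_j^{\top}}{\phi(Q_i)^{\top}\sum_{j=1}^{N}\phi(K_j)} \qquad (4)$$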
[0028]In equation (4), Q and K are decoupled. This means that $\sum_{j=1}^{N}\phi(K_j)V_j^{\top}$ and $\sum_{j=1}^{N}\phi(K_j)$ can be computed once and reused for every query. As such, the computation has time and memory complexity $O(N)$ rather than the $O(N^2)$ of the vanilla transformer 100. Thus, this modified transformer may be considered a linear transformer.
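The derivation in paragraphs [0025] to [0028] can be checked numerically with the following Python sketch. It is an illustration under stated assumptions, not the disclosed implementation: general_attention evaluates equation (2) for an arbitrary similarity function, softmax_sim recovers softmax attention, kernel_sim gives the kernelized similarity of equation (3), and linear_kernel_attention uses the associativity of equation (4). All function names are hypothetical, and the feature map elu(x) + 1 is borrowed from the linear-transformer literature purely as an example of a positive kernel; the embodiments herein use the MLP kernel described for the linear cell 200.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]
Similarity = Callable[[Vector, Vector], float]
FeatureMap = Callable[[Vector], Vector]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def general_attention(q: Matrix, k: Matrix, v: Matrix, sim: Similarity) -> Matrix:
    """Equations (2)/(3): similarity-weighted average of the values, O(N^2) pairs."""
    out = []
    for q_i in q:
        w = [sim(q_i, k_j) for k_j in k]
        total = sum(w)
        out.append([sum(w_j * v_j[c] for w_j, v_j in zip(w, v)) / total
                    for c in range(len(v[0]))])
    return out


def softmax_sim(q_i: Vector, k_j: Vector) -> float:
    """exp(q^T k / sqrt(d)): with this choice, equation (2) is softmax attention."""
    return math.exp(dot(q_i, k_j) / math.sqrt(len(q_i)))


def elu_plus_one(x: Vector) -> Vector:
    """Illustrative positive feature map phi(x) = elu(x) + 1."""
    return [t + 1.0 if t > 0 else math.exp(t) for t in x]


def kernel_sim(phi: FeatureMap) -> Similarity:
    """sim(q, k) = phi(q)^T phi(k), the kernelized similarity of equation (3)."""
    return lambda q_i, k_j: dot(phi(q_i), phi(k_j))


def linear_kernel_attention(q: Matrix, k: Matrix, v: Matrix, phi: FeatureMap) -> Matrix:
    """Equation (4): build S = sum_j phi(K_j) V_j^T and z = sum_j phi(K_j) once."""
    fk = [phi(k_j) for k_j in k]
    dim_f, dim_v = len(fk[0]), len(v[0])
    s = [[0.0] * dim_v for _ in range(dim_f)]
    z = [0.0] * dim_f
    for f, v_j in zip(fk, v):  # a single O(N) pass over keys and values
        for a in range(dim_f):
            z[a] += f[a]
            for c in range(dim_v):
                s[a][c] += f[a] * v_j[c]
    out = []
    for q_i in q:  # every query reuses the same S and z
        fq = phi(q_i)
        denom = dot(fq, z)
        out.append([sum(fq[a] * s[a][c] for a in range(dim_f)) / denom
                    for c in range(dim_v)])
    return out
```

The kernelized general_attention and linear_kernel_attention return the same outputs up to rounding, but the latter visits each key once in total rather than once per query.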
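[0029]We can write $s_i = \sum_{j=1}^{i}\phi(K_j)V_j^{\top}$ and $z_i = \sum_{j=1}^{i}\phi(K_j)$, with the sums restricted to positions $j \le i$ as required for causal generation, so that:
$$V'_i = \frac{\phi(Q_i)^{\top} s_i}{\phi(Q_i)^{\top} z_i}$$
Because $s_i = s_{i-1} + \phi(K_i)V_i^{\top}$ and $z_i = z_{i-1} + \phi(K_i)$, the attention can be expressed as a recurrence at test time, and each new token can be generated in constant time. Since $s_i$ and $z_i$ are computed from their past values, they have the form of an RNN hidden state or memory.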
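[0030]Consider a stream of tokens $X = [x_1, x_2, x_3, \ldots]$ to be generated. At inference time, the following update rule may be used, where subscripts denote the timestep in the recurrence, $q_i = W_Q x_i$, $k_i = W_K x_i$, $v_i = W_V x_i$, and $y_i$ is the attention output at timestep i:
$$s_i = s_{i-1} + \phi(k_i)\,v_i^{\top}, \qquad z_i = z_{i-1} + \phi(k_i), \qquad y_i = \frac{\phi(q_i)^{\top} s_i}{\phi(q_i)^{\top} z_i}$$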
The quantity $s_i$ acts as a constant-size KV cache: instead of new values being appended to the cache, the state is updated. Therefore, the longer the inference sequence, the greater the computational gain this formulation offers. However, to claim this gain, model performance has to be demonstrated at such long sequence lengths. This architecture allows for $O(1)$ per-token inference, but its performance lags that of vanilla attention transformers on natural language tasks. Furthermore, the normalization term $z_i$ leads to unbounded gradients, which is an important stability issue.
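The recurrence of paragraph [0030] can be sketched in Python as follows, using the illustrative feature map elu(x) + 1 in place of a learned kernel (an assumption). The class and function names are hypothetical. LinearAttentionState keeps only s_i and z_i, whose sizes do not grow with the sequence, and its outputs match a causal pairwise computation in which position i attends to positions 1 through i.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def elu_plus_one(x: Vector) -> Vector:
    """Illustrative positive feature map phi(x) = elu(x) + 1."""
    return [t + 1.0 if t > 0 else math.exp(t) for t in x]


class LinearAttentionState:
    """Constant-size memory (s_i, z_i) for causal linear attention."""

    def __init__(self, dim_f: int, dim_v: int) -> None:
        self.s = [[0.0] * dim_v for _ in range(dim_f)]  # acts like a fixed-size KV cache
        self.z = [0.0] * dim_f

    def step(self, q_i: Vector, k_i: Vector, v_i: Vector,
             phi: Callable[[Vector], Vector]) -> Vector:
        fk, fq = phi(k_i), phi(q_i)
        # s_i = s_{i-1} + phi(k_i) v_i^T  and  z_i = z_{i-1} + phi(k_i)
        for a, f in enumerate(fk):
            self.z[a] += f
            for c, val in enumerate(v_i):
                self.s[a][c] += f * val
        denom = sum(x * y for x, y in zip(fq, self.z))
        return [sum(fq[a] * self.s[a][c] for a in range(len(fq))) / denom
                for c in range(len(v_i))]


def causal_quadratic(q: Matrix, k: Matrix, v: Matrix,
                     phi: Callable[[Vector], Vector]) -> Matrix:
    """Reference: position i attends to keys 1..i via explicit pairwise sums."""
    out = []
    for i, q_i in enumerate(q):
        fq = phi(q_i)
        w = [sum(a * b for a, b in zip(fq, phi(k[j]))) for j in range(i + 1)]
        total = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / total
                    for c in range(len(v[0]))])
    return out
```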
[0031]To solve the above problems, in embodiments disclosed herein, a pre-trained transformer is uptrained. That is, instead of a transformer model being trained from scratch, a transformer model that has already been trained may be used as a starting point. This transformer may be a proprietary transformer model that has been trained on high-quality training data. The transformer may then be uptrained using MLP kernel attention to convert it into an RNN, as disclosed herein.
[0032]In embodiments, the architecture of the pre-trained transformer may be modified as disclosed herein. In particular, the pre-trained transformer may have the architecture of the vanilla transformer 100, as shown in
[0033]The fully connected layers 202, 204 receive the outputs of the fully connected layers 128, 126, respectively. The fully connected layers 202, 204 comprise parameters that are trained during uptraining, as disclosed herein. A ReLU activation function is applied to the outputs of the fully connected layers 202, 204, such that the output of the MLP 201 can be written as ϕ(x) = ReLU(Wx + b). A rotary position embedding (RoPE) is then applied to the outputs of the MLP 201, such that the similarity function becomes:
$$\mathrm{sim}(q_i, k_j) = \mathrm{RoPE}_i\big(\phi(q_i)\big)^{\top}\,\mathrm{RoPE}_j\big(\phi(k_j)\big)$$
[0034]The outputs of the RoPE layers 210, 212, associated with the query matrix Q and the key matrix K, respectively, are multiplied together at operation 214. The result is then multiplied by the output of the fully connected layer 124, associated with the value matrix V, at operation 216.
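The MLP feature map and rotary position embedding of paragraph [0033] can be sketched in Python as follows. This is an illustration only: the function names, the separate query and key parameters (wq, bq, wk, bk), and the RoPE variant used here (base 10000, rotation of consecutive feature pairs) are assumptions drawn from common RoPE practice rather than from the disclosure.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def mlp_feature_map(x: Vector, w: Matrix, b: Vector) -> Vector:
    """phi(x) = ReLU(W x + b); W and b are newly added, trainable parameters."""
    return [max(0.0, sum(wi * xi for wi, xi in zip(row, x)) + bi)
            for row, bi in zip(w, b)]


def rope(x: Vector, pos: int, base: float = 10000.0) -> Vector:
    """Rotate each pair (x[2m], x[2m+1]) by the angle pos * base**(-2m/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even feature dimension")
    out = list(x)
    for m in range(d // 2):
        theta = pos * base ** (-2.0 * m / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * m], x[2 * m + 1]
        out[2 * m] = a * c - b * s
        out[2 * m + 1] = a * s + b * c
    return out


def rope_similarity(q_i: Vector, k_j: Vector, i: int, j: int,
                    wq: Matrix, bq: Vector, wk: Matrix, bk: Vector) -> float:
    """sim(q_i, k_j) = RoPE_i(phi_q(q_i))^T RoPE_j(phi_k(k_j))."""
    fq = rope(mlp_feature_map(q_i, wq, bq), i)
    fk = rope(mlp_feature_map(k_j, wk, bk), j)
    return sum(a * b for a, b in zip(fq, fk))
```

Because the rotations compose, the resulting similarity depends only on the offset i - j. After rotation, the features are no longer guaranteed to be non-negative, so this similarity can be negative even though ϕ itself is non-negative.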
[0035]This output is then normalized by a Group Normalization operation, the GroupNorm operation 218, instead of being divided by the sum of $\mathrm{sim}(q_i, k_j)$. Group Normalization may be referred to herein as GroupNorm. The linear cell 200 also uses a fixed decay vector $\gamma \in (0,1)^h$, where h is the number of heads (not shown in
[0036]The new parameters of the linear cell 200 are then trained jointly with the rest of the transformer network having the pre-trained parameters. That is, the entire modified network, including the linear cells 200, is uptrained with additional training data. In the illustrated example, the modified transformer architecture may be uptrained for about 5% to 10% of the token budget for the pre-trained transformer. However, in other examples, other amounts of training data may be used for uptraining.
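The GroupNorm read-out and fixed decay of paragraph [0035] can be sketched in Python for a single token step as follows. The way the decay vector γ enters the state is assumed here to be a per-head update s_i = γ_h s_{i-1} + ϕ(k_i)v_i^T, and the class and function names are hypothetical; fq and fk stand for query and key features that have already passed through the MLP kernel and RoPE. GroupNorm is applied with one group per head, and its learnable scale and shift are omitted.

```python
import math
from typing import List

Vector = List[float]


def group_norm(y: Vector, num_groups: int, eps: float = 1e-5) -> Vector:
    """Normalize each contiguous group (one per head) to zero mean and unit variance.

    The learnable per-channel scale and shift of GroupNorm are omitted here.
    """
    size = len(y) // num_groups
    out: Vector = []
    for g in range(num_groups):
        chunk = y[g * size:(g + 1) * size]
        mean = sum(chunk) / size
        var = sum((c - mean) ** 2 for c in chunk) / size
        out.extend((c - mean) / math.sqrt(var + eps) for c in chunk)
    return out


class DecayedHeadState:
    """One head's state, assumed to follow s_i = gamma * s_{i-1} + phi(k_i) v_i^T."""

    def __init__(self, dim_f: int, dim_v: int, gamma: float) -> None:
        if not 0.0 < gamma < 1.0:
            raise ValueError("gamma must lie in (0, 1)")
        self.gamma = gamma
        self.s = [[0.0] * dim_v for _ in range(dim_f)]

    def step(self, fq: Vector, fk: Vector, v_i: Vector) -> Vector:
        for a in range(len(fk)):
            for c in range(len(v_i)):
                self.s[a][c] = self.gamma * self.s[a][c] + fk[a] * v_i[c]
        # Unnormalized read-out phi(q_i)^T s_i: no division by a sum of similarities.
        return [sum(fq[a] * self.s[a][c] for a in range(len(fq)))
                for c in range(len(v_i))]


def linear_cell_step(heads: List[DecayedHeadState], fqs: List[Vector],
                     fks: List[Vector], vs: List[Vector]) -> Vector:
    """Run every head for one token, then apply GroupNorm with one group per head."""
    raw: Vector = []
    for state, fq, fk, v in zip(heads, fqs, fks, vs):
        raw.extend(state.step(fq, fk, v))
    return group_norm(raw, num_groups=len(heads))
```

Replacing the division by the running similarity sum with GroupNorm keeps each head's output on a fixed scale without a denominator that could approach zero.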
[0037]The linear cell 200 of
[0038]
[0039]In the example of
[0040]The network interface hardware 306 can be communicatively coupled to the communication path 308 and can be any device capable of transmitting and/or receiving data via a network. Accordingly, the network interface hardware 306 can include a communication transceiver for sending and/or receiving any wired or wireless communication. For example, the network interface hardware 306 may include an antenna, a modem, a LAN port, a Wi-Fi card, a WiMAX card, mobile communications hardware, near-field communication hardware, satellite communication hardware, and/or any wired or wireless hardware for communicating with other networks and/or devices. In one embodiment, the network interface hardware 306 includes hardware configured to operate in accordance with the Bluetooth® wireless communication protocol. The network interface hardware 306 of the computing device 300 may receive parameters of a pre-trained transformer, as disclosed in further detail below.
[0041]The one or more memory modules 304 include a database 312, a pre-trained transformer reception module 314, an architecture modification module 316, a training data reception module 318, an uptraining module 320, and an inference module 322. Each of the database 312, the pre-trained transformer reception module 314, the architecture modification module 316, the training data reception module 318, the uptraining module 320, and the inference module 322 may be a program module in the form of operating systems, application program modules, and other program modules stored in the one or more memory modules 304. In some embodiments, the program module may be stored in a remote storage device that may communicate with the computing device 300. Such a program module may include, but is not limited to, routines, subroutines, programs, objects, components, data structures and the like for performing specific tasks or executing specific data types as will be described below.
[0042]The database 312 may store parameters of a pre-trained transformer received by the pre-trained transformer reception module 314, parameters of the modified transformer, and training data used to uptrain the modified transformer.
[0043]The pre-trained transformer reception module 314 may receive parameters of a pre-trained transformer. As discussed above, a pre-trained transformer (e.g., a proprietary model trained on high-quality training data) may be used as a starting point for uptraining. As such, the pre-trained transformer reception module 314 may receive the parameters of a pre-trained transformer model to be uptrained (e.g., parameters of the vanilla transformer 100).
[0044]The architecture modification module 316 may modify the architecture of the pre-trained transformer associated with the pre-trained parameters received by the pre-trained transformer reception module 314. In particular, as discussed above, the architecture modification module 316 may replace each attention layer of the pre-trained transformer (e.g., layers 110, 114, 116 of the transformer 100 shown in
[0045]The training data reception module 318 may receive training data for uptraining the modified transformer generated by the architecture modification module 316. In particular, as discussed above, the modified transformer may be trained using only a fraction of the tokens used to train the original pre-trained model (e.g., 5% to 10% of the tokens). In some examples, the actual training data used to train the pre-trained model may be available. In these examples, a portion of this training data may be used to uptrain the modified transformer model. However, if the actual training data used to pre-train the transformer is not available (e.g., because a proprietary model was used as the pre-trained model), then any set of tokens may be used for uptraining. Ideally, the tokens used to uptrain the modified transformer are drawn from a distribution similar to that of the tokens used to train the original pre-trained transformer.
[0046]The uptraining module 320 may uptrain the modified transformer generated by the architecture modification module 316 using the training data received by the training data reception module 318. In particular, the weights of the fully connected layers 202 and 204 of each linear cell 200 added by the architecture modification module 316 may be initialized with random values. The other parameters of the modified transformer may initially utilize the pre-trained values received by the pre-trained transformer reception module 314. The entire modified transformer may then be trained using the training data received by the training data reception module 318. The uptraining module 320 may train the modified transformer using known techniques. After the modified transformer is trained by the uptraining module 320, the parameters of the trained model may be stored in the database 312.
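The initialization described above can be sketched in Python as follows. The dictionary-of-matrices parameter format, the names q_mlp and k_mlp (standing in for the fully connected layers 202 and 204), the Gaussian initialization with standard deviation 1/√head_dim, and the zero-initialized biases are all assumptions for illustration; the joint training loop itself is not shown.

```python
import random
from typing import Dict, List

Matrix = List[List[float]]


def init_uptraining_params(pretrained: Dict[str, Matrix], num_layers: int,
                           feat_dim: int, head_dim: int,
                           seed: int = 0) -> Dict[str, Matrix]:
    """Copy the pre-trained weights and add randomly initialized kernel MLP layers.

    The parameter names ("q_mlp", "k_mlp", ...) are hypothetical placeholders
    for the fully connected layers 202 and 204 of each linear cell.
    """
    rng = random.Random(seed)
    # Existing parameters start from their pre-trained values (copied, not shared).
    params = {name: [row[:] for row in w] for name, w in pretrained.items()}
    std = head_dim ** -0.5
    for layer in range(num_layers):
        for proj in ("q_mlp", "k_mlp"):
            # Weight maps a head_dim input to feat_dim features: feat_dim rows.
            params[f"layer{layer}.{proj}.weight"] = [
                [rng.gauss(0.0, std) for _ in range(head_dim)] for _ in range(feat_dim)
            ]
            params[f"layer{layer}.{proj}.bias"] = [[0.0] * feat_dim]  # zero bias: an assumption
    return params
```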
[0047]The inference module 322 may receive a query, and generate a response to the query using the trained modified transformer. In the illustrated example, during inference, each linear cell 200 of the modified transformer may be replaced by the RNN cell 220, as shown in
[0048]
[0049]It should now be understood that embodiments described herein are directed to converting and uptraining performant transformers with full attention using normalized recurrence. By modifying a pre-trained transformer to replace softmax normalization with GroupNorm normalization, a linear transformer can be generated that performs inference much more efficiently than a vanilla transformer. Furthermore, only a small fraction of the tokens used to train the original transformer need to be used to uptrain the modified transformer. As such, a modified transformer may be used for LLMs or other applications with more efficient inference capabilities.
[0050]It is noted that the terms “substantially” and “about” may be utilized herein to represent the inherent degree of uncertainty that may be attributed to any quantitative comparison, value, measurement, or other representation. These terms are also utilized herein to represent the degree by which a quantitative representation may vary from a stated reference without resulting in a change in the basic function of the subject matter at issue.
[0051]While particular embodiments have been illustrated and described herein, it should be understood that various other changes and modifications may be made without departing from the spirit and scope of the claimed subject matter. Moreover, although various aspects of the claimed subject matter have been described herein, such aspects need not be utilized in combination. It is therefore intended that the appended claims cover all such changes and modifications that are within the scope of the claimed subject matter.
Claims
What is claimed is:
1. A method comprising:
receiving parameters associated with a pre-trained transformer trained on first training data;
modifying an architecture of the pre-trained transformer to generate a modified transformer, the modified transformer replacing a dot-product softmax attention layer with a linear kernel dot product attention layer utilizing Group Normalization;
receiving second training data; and
training the modified transformer based on the second training data.
2. The method of
the pre-trained transformer comprises a vanilla transformer architecture having a plurality of multi-head attention layers using softmax normalization; and
the modified transformer uses a linear cell in place of one or more of the multi-head attention layers, the linear cell using Group Normalization in place of softmax normalization.
3. The method of
a first fully connected layer associated with a query matrix;
a second fully connected layer associated with a key matrix;
a third fully connected layer associated with a value matrix;
a fourth fully connected layer to receive an output of the first fully connected layer; and
a fifth fully connected layer to receive an output of the second fully connected layer.
4. The method of
the first fully connected layer, the second fully connected layer, and the third fully connected layer are included in the pre-trained transformer; and
the fourth fully connected layer and the fifth fully connected layer are not included in the pre-trained transformer.
5. The method of
a first rectified linear unit (ReLU) activation function to be performed on an output of the fourth fully connected layer; and
a second ReLU activation function to be performed on an output of the fifth fully connected layer.
6. The method of
a first rotary position embedding (RoPE) applied to an output of the first ReLU activation function; and
a second RoPE applied to an output of the second ReLU activation function.
7. The method of
a first matrix multiplication operation between an output of the first RoPE and an output of the second RoPE.
8. The method of
a second matrix multiplication operation between an output of the first matrix multiplication operation and an output of the third fully connected layer.
9. The method of
10. The method of
receiving a query;
inputting the query into the trained modified transformer; and
generating a response to the query based on an output of the trained modified transformer.
11. A computing device comprising one or more processors configured to:
receive parameters associated with a pre-trained transformer trained on first training data;
modify an architecture of the pre-trained transformer to generate a modified transformer, the modified transformer replacing a dot-product softmax attention layer with a linear kernel dot product attention layer utilizing Group Normalization;
receive second training data; and
train the modified transformer based on the second training data.
12. The computing device of
the pre-trained transformer comprises a vanilla transformer architecture having a plurality of multi-head attention layers using softmax normalization; and
the modified transformer uses a linear cell in place of one or more of the multi-head attention layers, the linear cell using Group Normalization in place of softmax normalization.
13. The computing device of
a first fully connected layer associated with a query matrix;
a second fully connected layer associated with a key matrix;
a third fully connected layer associated with a value matrix;
a fourth fully connected layer to receive an output of the first fully connected layer; and
a fifth fully connected layer to receive an output of the second fully connected layer.
14. The computing device of
the first fully connected layer, the second fully connected layer, and the third fully connected layer are included in the pre-trained transformer; and
the fourth fully connected layer and the fifth fully connected layer are not included in the pre-trained transformer.
15. The computing device of
a first rectified linear unit (ReLU) activation function to be performed on an output of the fourth fully connected layer; and
a second ReLU activation function to be performed on an output of the fifth fully connected layer.
16. The computing device of
a first rotary position embedding (RoPE) applied to an output of the first ReLU activation function; and
a second RoPE applied to an output of the second ReLU activation function.
17. The computing device of
a first matrix multiplication operation between an output of the first RoPE and an output of the second RoPE.
18. The computing device of
a second matrix multiplication operation between an output of the first matrix multiplication operation and an output of the third fully connected layer.
19. The computing device of
20. The computing device
receive a query;
input the query into the trained modified transformer; and
generate a response to the query based on an output of the trained modified transformer.