.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_prototype_nestedtensor.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_prototype_nestedtensor.py:

Nested Tensors
===============================================================

A nested tensor is very similar to a regular tensor, except for its shape:

* for a regular tensor, each dimension has a size
* for a nested tensor, not all dimensions have regular sizes; some of them are jagged

Nested tensors are a natural solution for representing sequential data within various domains:

* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor
* in CV, images can have variable shapes, so a batch of images forms a nested tensor

In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness
for operating on sequential data of varying lengths with a real-world example.

The nested tensor operations used here have not been released yet. You will have to install the
latest nightly build to run this tutorial.

.. code-block:: default

    import torch
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Nested Tensor Initialization
----------------------------

From the Python frontend, a nested tensor can be created from a list of tensors.

.. code-block:: default

    nt = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], device=device)
    print(nt)

By padding every underlying tensor to the same shape, a nested tensor can be converted to a regular tensor.

.. code-block:: default

    pt = nt.to_padded_tensor(0.0)
    print(pt)

For practical reasons, we conceptually treat a nested tensor as a batch of tensors with
different shapes, i.e. dimension 0 is assumed to be the batch dimension.
Indexing into dimension 0 gives back the corresponding underlying tensor.

.. code-block:: default

    print("0th underlying tensor:", nt[0], sep='\n')
    print("last column of 1st underlying tensor:", nt[1, :, -1], sep='\n')

Slicing in dimension 0 is not supported yet.

Nested Tensor Operations
------------------------

As each operation must be explicitly implemented for nested tensors, operation coverage for
nested tensors is currently narrower than that of regular tensors.
For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm
are covered. However, coverage is being expanded rapidly.
If you need certain operations, please file an
`issue <https://github.com/pytorch/pytorch/issues>`__ to help us prioritize coverage.

**reshape**

The reshape op is for changing the shape of a tensor.
Its full semantics for regular tensors can be found
`here <https://pytorch.org/docs/stable/generated/torch.Tensor.reshape.html>`__.
For regular tensors, when specifying the new shape, a single dimension may be -1,
in which case it is inferred from the remaining dimensions and the number of elements.

The semantics for nested tensors are similar, except that -1 no longer infers a size.
Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
-1 is the only legal size to specify for a jagged dimension.

.. code-block:: default

    nt1 = nt.reshape(2, -1, 2, 3)
    print(nt1)

**transpose**

The transpose op is for swapping two dimensions of a tensor.
Its full semantics can be found
`here <https://pytorch.org/docs/stable/generated/torch.Tensor.transpose.html>`__.
Note that dimension 0 of a nested tensor is special; it is assumed to be the batch dimension,
so transposes involving nested tensor dimension 0 are forbidden.

.. code-block:: default

    nt2 = nt1.transpose(1, 2)
    print(nt2)
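As a minimal sketch of the restriction above (this snippet is an added illustration and assumes
the same nightly build used throughout this tutorial), a transpose that involves dimension 0 is
expected to raise an error rather than silently succeed:

.. code-block:: default

    # a sketch: swapping the batch dimension of a nested tensor is expected to fail
    try:
        nt1.transpose(0, 1)
    except RuntimeError as e:
        print("transpose involving dimension 0 raised:", e)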
**others**

Other operations have the same semantics as for regular tensors.
Applying an operation to a nested tensor is equivalent to applying the operation to the
underlying tensor components, with the result being a nested tensor as well.

.. code-block:: default

    nt_mm = torch.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
    nt3 = torch.matmul(nt2, nt_mm)
    print("matmul:", nt3, sep='\n')

    nt4 = F.dropout(nt3, 0.1)
    print("dropout:", nt4, sep='\n')

    nt5 = F.softmax(nt4, -1)
    print("softmax:", nt5, sep='\n')

Why Nested Tensor
-----------------

Before nested tensors existed, one had to manually pad each data tensor to the same shape in
order to form a batch as a regular tensor.
For example, given 2 sentences and a vocabulary, the shorter sentence is padded with 0.

.. code-block:: default

    sentences = [["goodbye", "padding"],
                 ["embrace", "nested", "tensor"]]
    vocabulary = {"goodbye" : 1.0, "padding" : 2.0,
                  "embrace" : 3.0, "nested" : 4.0, "tensor" : 5.0}
    padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                     [3.0, 4.0, 5.0]])
    nested_sentences = torch.nested_tensor([torch.tensor([1.0, 2.0]),
                                            torch.tensor([3.0, 4.0, 5.0])])
    print(padded_sentences)
    print(nested_sentences)

Clearly, padding introduces inefficiency.
Furthermore, padding with zeros is not handled correctly as padding by every operation;
e.g. in softmax one has to pad with -inf rather than 0 to ignore specific entries.

.. code-block:: default

    padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                                 [3.0, 4.0, 5.0]])
    print(F.softmax(padded_sentences_for_softmax, -1))
    print(F.softmax(nested_sentences, -1))
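As a small sketch (an added illustration, assuming ``to_padded_tensor`` accepts an arbitrary
padding value, as it did with 0.0 above), the operation-appropriate padded tensor can also be
produced directly from the nested tensor instead of being written out by hand:

.. code-block:: default

    # a sketch: pad with -inf directly from the nested tensor, then softmax as usual
    print(F.softmax(nested_sentences.to_padded_tensor(float("-inf")), -1))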
Let us take a look at a practical example: the multi-head attention component
used in `Transformers <https://arxiv.org/abs/1706.03762>`__.
The nested tensor version is straightforward.

.. code-block:: default

    import math

    """
    Args:
        query: query of shape (N, L_t, E_q)
        key: key of shape (N, L_s, E_k)
        value: value of shape (N, L_s, E_v)
        nheads: number of heads in multi-head attention
        W_q: Weight for query input projection of shape (E_total, E_q)
        W_k: Weight for key input projection of shape (E_total, E_k)
        W_v: Weight for value input projection of shape (E_total, E_v)
        W_out: Weight for output projection of shape (E_out, E_total)
        b_q (optional): Bias for query input projection of shape E_total. Default: None
        b_k (optional): Bias for key input projection of shape E_total. Default: None
        b_v (optional): Bias for value input projection of shape E_total. Default: None
        b_out (optional): Bias for output projection of shape E_out. Default: None
        dropout_p: dropout probability. Default: 0.0

        where:
            N is the batch size
            L_t is the target sequence length (jagged)
            L_s is the source sequence length (jagged)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size

    Returns:
        attn_output: Output of shape (N, L_t, E_out)
    """
    def mha_nested(query, key, value, nheads,
                   W_q, W_k, W_v, W_out,
                   b_q=None, b_k=None, b_v=None, b_out=None,
                   dropout_p=0.0):
        N = query.size(0)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.reshape(-1, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.reshape(-1, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.reshape(-1, -1, nheads, E_head).transpose(1, 2)

        # query matmul key^T
        # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)
        attn_weights = torch.matmul(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights matmul value
        # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
        attn_output = torch.matmul(attn_weights, value)

        # merge heads
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out, b_out)

        return attn_output
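For reference, the core computation performed by ``mha_nested`` for each head is the standard
scaled dot-product attention:

.. math::

    \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{E_{\text{head}}}}\right) V

The per-head outputs are then concatenated and passed through the output projection ``W_out``.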
The 0-padded tensor version additionally requires masks for more complicated treatment of the
padded entries.

.. code-block:: default

    """
    Args:
        query: query of shape (N, L_t, E_q)
        key: key of shape (N, L_s, E_k)
        value: value of shape (N, L_s, E_v)
        nheads: number of heads in multi-head attention
        attn_mask_q: boolean mask indicating locations that should not take part in attention for query,
                     shape (N, L_t)
        attn_mask_kv: boolean mask indicating locations that should not take part in attention for key and value,
                      shape (N, L_s)
        W_q: Weight for query input projection of shape (E_total, E_q)
        W_k: Weight for key input projection of shape (E_total, E_k)
        W_v: Weight for value input projection of shape (E_total, E_v)
        W_out: Weight for output projection of shape (E_out, E_total)
        b_q (optional): Bias for query input projection of shape E_total. Default: None
        b_k (optional): Bias for key input projection of shape E_total. Default: None
        b_v (optional): Bias for value input projection of shape E_total. Default: None
        b_out (optional): Bias for output projection of shape E_out. Default: None
        dropout_p: dropout probability. Default: 0.0

        where:
            N is the batch size
            L_t is the target sequence length (padded)
            L_s is the source sequence length (padded)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size

    Returns:
        attn_output: Output of shape (N, L_t, E_out)
    """
    def mha_padded(query, key, value, nheads,
                   attn_mask_q, attn_mask_kv,
                   W_q, W_k, W_v, W_out,
                   b_q=None, b_k=None, b_v=None, b_out=None,
                   dropout_p=0.0):
        N = query.size(0)
        L_t = query.size(1)
        L_s = key.size(1)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # padding-specific step: remove bias from padded entries
        # in this particular multi-head attention formula it is not strictly necessary to remove the bias,
        # because the -inf padding in the softmax step below takes care of it,
        # but for generality we demonstrate the bias removal here
        for i in range(N):
            for j in range(L_t):
                if attn_mask_q[i, j]:
                    query[i, j, :] = 0.0
            for j in range(L_s):
                if attn_mask_kv[i, j]:
                    key[i, j, :] = 0.0
                    value[i, j, :] = 0.0

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
        query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)

        # query bmm key^T
        # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)

        # padding-specific step: add -inf mask for padding in softmax
        attn_mask = query.new_zeros((N, nheads, L_t, L_s))
        for i in range(N):
            for j in range(L_t):
                for k in range(L_s):
                    if attn_mask_q[i, j] or attn_mask_kv[i, k]:
                        attn_mask[i, :, j, k] = float("-inf")
        attn_mask = attn_mask.reshape((N * nheads, L_t, L_s))
        attn_weights = torch.baddbmm(attn_mask, query, keyT)
        # if there were no padding, this could have been as simple as
        # attn_weights = torch.bmm(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights bmm value
        # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
        attn_output = attn_weights.bmm(value)

        # merge heads
        # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out, b_out)

        # padding-specific step: remove output projection bias from padded entries
        for i in range(N):
            for j in range(L_t):
                if attn_mask_q[i, j]:
                    attn_output[i, j, :] = 0.0

        return attn_output
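The triple Python loop that builds ``attn_mask`` inside ``mha_padded`` is written for clarity.
The same mask can be constructed with broadcasting; the sketch below (an added illustration using
hypothetical ``toy_*`` names and toy shapes) shows the equivalent vectorized logic:

.. code-block:: default

    # a sketch of the equivalent vectorized -inf mask construction, on toy shapes
    toy_N, toy_nheads, toy_L_t, toy_L_s = 2, 2, 3, 4
    toy_mask_q = torch.zeros((toy_N, toy_L_t), dtype=torch.bool)
    toy_mask_kv = torch.zeros((toy_N, toy_L_s), dtype=torch.bool)
    toy_mask_q[0, -1] = True       # pretend the last query position of sample 0 is padding
    toy_mask_kv[0, -2:] = True     # and the last two key/value positions

    # a (query, key) pair is masked if either side is padding
    combined = toy_mask_q.unsqueeze(-1) | toy_mask_kv.unsqueeze(1)   # (N, L_t, L_s)
    toy_attn_mask = torch.zeros((toy_N, toy_L_t, toy_L_s))
    toy_attn_mask.masked_fill_(combined, float("-inf"))
    # replicate over heads and flatten the batch and head dimensions
    toy_attn_mask = toy_attn_mask.unsqueeze(1).expand(-1, toy_nheads, -1, -1).reshape(toy_N * toy_nheads, toy_L_t, toy_L_s)
    print(toy_attn_mask)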
Set the hyperparameters following `the Transformer paper <https://arxiv.org/abs/1706.03762>`__.

.. code-block:: default

    N = 512
    E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
    nheads = 8

The exception is the dropout probability, which is set to 0 for the correctness check.

.. code-block:: default

    dropout_p = 0.0

Let us generate some realistic fake data from Zipf's law.

.. code-block:: default

    import numpy as np

    def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
        # generate fake corpus by unigram Zipf distribution
        # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
        sentence_lengths = np.empty(batch_size, dtype=int)
        for ibatch in range(batch_size):
            sentence_lengths[ibatch] = 1
            word = np.random.zipf(alpha)
            while word != 3 and word != 386 and word != 858:
                sentence_lengths[ibatch] += 1
                word = np.random.zipf(alpha)
        return sentence_lengths

    alpha = 1.2
    sentence_lengths = zipf_sentence_lengths(alpha, N)
    L_t = np.max(sentence_lengths)
    L_s = L_t

Create the inputs.

.. code-block:: default

    # create parameters
    W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
    W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
    W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
    W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

    # create nested input
    queries = []
    keys = []
    values = []
    for i in range(N):
        l = sentence_lengths[i]
        s = l
        queries.append(torch.randn((l, E_q), device=device))
        keys.append(torch.randn((s, E_k), device=device))
        values.append(torch.randn((s, E_v), device=device))
    query = torch.nested_tensor(queries)
    key = torch.nested_tensor(keys)
    value = torch.nested_tensor(values)

    # pad input
    padded_query = query.to_padded_tensor(0.0, (N, L_t, E_q))
    padded_key = key.to_padded_tensor(0.0, (N, L_s, E_k))
    padded_value = value.to_padded_tensor(0.0, (N, L_s, E_v))

    # create attention masks
    attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
    attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)
    for i in range(N):
        for j in range(L_t):
            if padded_query[i, j, :].abs().max().item() == 0.0:
                attn_mask_q[i, j] = True
        for j in range(L_s):
            if padded_key[i, j, :].abs().max().item() == 0.0:
                attn_mask_kv[i, j] = True

Check correctness and performance.

.. code-block:: default

    import timeit

    t0 = timeit.default_timer()
    out_nested = mha_nested(
        query, key, value, nheads,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t1 = timeit.default_timer()
    out_padded = mha_padded(
        padded_query, padded_key, padded_value, nheads,
        attn_mask_q, attn_mask_kv,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)
    t2 = timeit.default_timer()

    print("nested and padded calculations differ by",
          (out_nested.to_padded_tensor(0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
    print("nested tensor multi-head attention takes", t1 - t0, "seconds")
    print("padded tensor multi-head attention takes", t2 - t1, "seconds")

The nested tensor version avoids wasted computation on padding, so in sequential CPU execution
it is faster than the padded tensor version, as expected.
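To make the wasted computation concrete, here is a rough estimate (an added sketch, using the
Zipf sentence lengths generated above) of how much of the padded batch consists of padding:

.. code-block:: default

    # a sketch: fraction of the padded (N, L_t) grid that is padding rather than data
    padding_fraction = 1.0 - sentence_lengths.sum() / (N * L_t)
    print("fraction of padded entries that are padding:", padding_fraction)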
Optimization for multi-threaded environments is underway.
For now, performant kernels are provided for specific use cases, e.g. self-attention
evaluation by the multi-head attention formula.

.. code-block:: default

    # embeddings are assumed to be the same
    E = E_total
    mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
    mha_lib.eval()

Extract the parameters for the correctness check.

.. code-block:: default

    mha_lib.in_proj_weight.requires_grad_(False)
    mha_lib.in_proj_bias.requires_grad_(False)
    mha_lib.out_proj.weight.requires_grad_(False)
    mha_lib.out_proj.bias.requires_grad_(False)
    W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
    W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
    W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
    W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias

Check correctness and performance.

.. code-block:: default

    t0 = timeit.default_timer()
    out_lib, out_lib_weights = mha_lib(query, query, query)

    t1 = timeit.default_timer()
    out_nested = mha_nested(
        query, query, query, nheads,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t2 = timeit.default_timer()
    padded_out = mha_padded(
        padded_query, padded_query, padded_query, nheads,
        attn_mask_q, attn_mask_q,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)
    t3 = timeit.default_timer()

    print("nested general and library calculations differ by",
          (out_nested.to_padded_tensor(0.0) - out_lib.to_padded_tensor(0.0)).abs().max().item())
    print("nested library multi-head attention takes", t1 - t0, "seconds")
    print("nested general multi-head attention takes", t2 - t1, "seconds")
    print("padded tensor multi-head attention takes", t3 - t2, "seconds")

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_prototype_nestedtensor.py:

.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download

     :download:`Download Python source code: nestedtensor.py <nestedtensor.py>`

  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: nestedtensor.ipynb <nestedtensor.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_