Synchronizations With TorchRec KeyedJaggedTensor
Introduction
In recommendation systems, sparse features such as user-item interaction ids are often used to model user preferences and item characteristics. These sparse features are then mapped to dense representations through large embedding tables.
However, there are a few challenges when working with sparse features in recommendation systems:
- Different samples may have different numbers of interactions, leading to variable-length input data.
- Recommendation systems often use a large number of sparse features.
In a batch of requests, if all sparse features are padded to the same length, the embedding tables will produce many useless embedding vectors. That wastes memory and downstream compute resources. If each sparse feature accesses embedding tables independently, the overhead becomes large when the number of sparse features is large.
TorchRec KeyedJaggedTensor was designed to address these challenges by combining sparse features across samples and across features into one large sparse feature without padding. This eliminates the memory and compute inefficiencies.
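To make the cost of padding concrete, here is a minimal plain-Python sketch (it does not use TorchRec). The batch contents and the helper names `padded_slot_count` and `jagged_slot_count` are made up for illustration; the point is only to compare how many id slots each layout needs.

```python
# Hypothetical batch: feature name -> one list of ids per sample.
batch = {
    "feature_a": [[1, 2, 3, 4, 5], [6], []],
    "feature_b": [[7], [8, 9], [10]],
}


def padded_slot_count(batch):
    """Slots needed if every id list is padded to the longest list in the batch."""
    max_len = max(len(ids) for samples in batch.values() for ids in samples)
    num_lists = sum(len(samples) for samples in batch.values())
    return max_len * num_lists


def jagged_slot_count(batch):
    """Slots needed if only the real ids are stored back to back."""
    return sum(len(ids) for samples in batch.values() for ids in samples)


padded = padded_slot_count(batch)  # 5 slots * 6 lists = 30
jagged = jagged_slot_count(batch)  # 5 + 1 + 0 + 1 + 2 + 1 = 10
print(f"padded={padded}, jagged={jagged}, wasted={padded - jagged}")
```

In this toy batch, two thirds of the padded slots would trigger embedding lookups whose results are thrown away.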
Despite its efficiency, KeyedJaggedTensor has several caveats and can be used incorrectly, resulting in worse system performance. One key issue in GPU systems is synchronization. In this blog post, I would like to discuss the main pitfalls of KeyedJaggedTensor and how to use it efficiently on GPU.
TorchRec Data Types
TorchRec modules use dedicated input and output data types to represent sparse features efficiently, including:
- JaggedTensor: a wrapper around the lengths or offsets tensor and the values tensor for a single sparse feature.
- KeyedJaggedTensor: a wrapper that represents multiple sparse features and can be thought of as multiple JaggedTensors.
- KeyedTensor: a wrapper around torch.Tensor that allows access to tensor values through keys.
KeyedJaggedTensor can be constructed from a dictionary of JaggedTensors, where the keys are the feature names. When a KeyedJaggedTensor is passed to an EmbeddingBagCollection, the output is a KeyedTensor, whose embeddings can be accessed through keys.
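The relationship between the lengths and offsets representations of a single jagged feature can be sketched in plain Python. This models the layout only and is not TorchRec code; the sample data and the helper `values_for_sample` are assumptions for illustration.

```python
from itertools import accumulate

# One sparse feature, with the ids of all samples stored back to back.
values = [10, 11, 12]
lengths = [2, 1]  # sample 0 has 2 ids, sample 1 has 1 id

# Offsets are the running sum of lengths with a leading zero, so that
# sample i occupies values[offsets[i]:offsets[i + 1]].
offsets = [0, *accumulate(lengths)]


def values_for_sample(values, offsets, i):
    """Return the ids that belong to sample i."""
    return values[offsets[i]:offsets[i + 1]]


print(offsets)
print(values_for_sample(values, offsets, 0))
print(values_for_sample(values, offsets, 1))
```

Either representation can be recovered from the other, which is why a jagged wrapper only needs to carry one of them alongside the values.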
Synchronizations With TorchRec KeyedJaggedTensor
When using KeyedJaggedTensor, the critical question is what the symbolic shape of an operation's output will be for a given input KeyedJaggedTensor. Such an operation could be looking up an EmbeddingBagCollection with the KeyedJaggedTensor, or getting the values tensor that corresponds to a specific key. The output's symbolic shape cannot be derived from the symbolic shapes of the values, lengths, or offsets tensors in KeyedJaggedTensor. In other words, any operation that takes KeyedJaggedTensor as input is data dependent, and its output shape can only be determined from the actual lengths data at runtime. If the values, lengths, and offsets tensors are on GPU, TorchRec has to copy the lengths from GPU to CPU to infer the output shape, which introduces synchronization and can hurt performance.
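A small plain-Python sketch can illustrate why the shapes alone are not enough. The two hypothetical batches below have lengths lists of the same size and the same total number of values, yet the number of values per key differs. The helper `values_per_key` and the key-major layout of `lengths` are assumptions of this sketch, not TorchRec calls.

```python
def values_per_key(lengths, num_keys):
    """Sum the key-major lengths of each key. This reads the lengths data itself."""
    batch_size = len(lengths) // num_keys
    return [
        sum(lengths[k * batch_size:(k + 1) * batch_size])
        for k in range(num_keys)
    ]


# Two batches with 2 keys and 2 samples each, 7 values in total.
lengths_a = [2, 1, 1, 3]
lengths_b = [3, 2, 1, 1]

# The "shapes" agree: same number of lengths, same number of values.
same_shapes = len(lengths_a) == len(lengths_b) and sum(lengths_a) == sum(lengths_b)

per_key_a = values_per_key(lengths_a, num_keys=2)
per_key_b = values_per_key(lengths_b, num_keys=2)
print(same_shapes, per_key_a, per_key_b)
```

Because the per-key sizes depend on the contents of `lengths`, a system that holds `lengths` on GPU has to read it back to the host before it can size the per-key outputs.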
To mitigate this problem, KeyedJaggedTensor saves key metadata in lists when it is constructed. For some metadata, such as lengths per key, the values can be determined directly from the corresponding JaggedTensor without looking at the flattened lengths tensor. In this way, even after the original JaggedTensors are no longer available, the output shape can still be determined without reading the actual lengths data, which avoids GPU-CPU synchronization when KeyedJaggedTensor is on GPU.
However, this is not how the key metadata is derived in the current implementation of KeyedJaggedTensor. In eager mode, the metadata is derived from the actual data in the lengths tensors, which causes GPU-CPU synchronization if KeyedJaggedTensor is on GPU. In TorchDynamo compile mode, the metadata is not computed and saved during construction to avoid that synchronization. This only defers the synchronization to the point when KeyedJaggedTensor is used in an operation, so it does not really solve the problem.
One might ask why the current KeyedJaggedTensor implementation does not just derive the key metadata from the JaggedTensor metadata, since that would avoid GPU-CPU synchronization altogether. The reason is that constructing this metadata from JaggedTensor objects cannot be traced into the computation graph in TorchDynamo compile mode, at least for now, because it involves Python list appending and other bookkeeping that the computation graph cannot support. KeyedJaggedTensor is not a torch.Tensor after all; only the metadata in a torch.Tensor can be symbolically traced in TorchDynamo compile mode. In compile modes other than TorchDynamo, the current implementation may record the construction in the computation graph, but GPU-CPU synchronization is still unavoidable if the JaggedTensors are on GPU.
Consequently, if KeyedJaggedTensor is constructed from GPU JaggedTensors and used in compile mode, GPU-CPU synchronization is inevitable.
An Illustration of KeyedJaggedTensor Metadata Problems
Suppose a batch has two sparse features, user_clicked_item_ids and user_viewed_item_ids, and two samples:
- Sample 0: user_clicked_item_ids = [10, 11], user_viewed_item_ids = [20]
- Sample 1: user_clicked_item_ids = [12], user_viewed_item_ids = [21, 22, 23]
These can be flattened into one KeyedJaggedTensor as:
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
Here, lengths stores the number of values per feature per sample, in key order:
- user_clicked_item_ids for sample 0 has length 2
- user_clicked_item_ids for sample 1 has length 1
- user_viewed_item_ids for sample 0 has length 1
- user_viewed_item_ids for sample 1 has length 3
This small example is enough to see why KeyedJaggedTensor is data dependent. If a downstream operator needs the output shape for user_viewed_item_ids, it has to know that the lengths for that key sum to 1 + 3 = 4. When those lengths live on GPU, TorchRec has to move them back to CPU to determine the shape, which creates the synchronization.
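The flattened layout of this example can be modeled in plain Python as follows. This is a sketch of the data layout under the assumption that values and lengths are stored key by key; it uses Python lists rather than TorchRec or tensors, and the helper `values_for_key` is hypothetical.

```python
from itertools import accumulate

keys = ["user_clicked_item_ids", "user_viewed_item_ids"]
# All ids stored key by key, and sample by sample within each key.
values = [10, 11, 12, 20, 21, 22, 23]
# One length per (key, sample) pair, in the same key-major order.
lengths = [2, 1, 1, 3]

batch_size = len(lengths) // len(keys)

# Number of values owned by each key. Computing this reads the lengths data.
length_per_key = [
    sum(lengths[k * batch_size:(k + 1) * batch_size]) for k in range(len(keys))
]
# Where each key's slice starts and ends inside the flattened values.
offset_per_key = [0, *accumulate(length_per_key)]


def values_for_key(key):
    """Return the flattened ids of one key across all samples."""
    k = keys.index(key)
    return values[offset_per_key[k]:offset_per_key[k + 1]]


print(length_per_key, offset_per_key)
print(values_for_key("user_viewed_item_ids"))
```

The slice boundaries for a key are only known after summing its lengths, which is exactly the step that requires the lengths data on the host.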
If we only want to derive metadata from the JaggedTensors themselves, the idea is much simpler. Conceptually, TorchRec could build the key metadata directly from the per-key JaggedTensor objects:
from torchrec.sparse.jagged_tensor import JaggedTensor
In this illustration, the metadata is derived from the JaggedTensor structure itself, not from the flattened KeyedJaggedTensor values. That is the shape information we would want to preserve, because it can be known before any downstream operation touches the actual sparse values. However, this pattern is not suitable for TorchDynamo compile mode today, because the Python-side bookkeeping needed to collect the metadata is not traceable into the computation graph.
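The idea of collecting per-key metadata from the structure of each per-feature jagged tensor can be sketched as below. `ToyJaggedTensor` and `build_key_metadata` are hypothetical stand-ins written for this illustration, not the TorchRec classes or their implementation. The sketch relies on the fact that the number of values a key owns equals the size of that key's values container, so the lengths contents never need to be read.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ToyJaggedTensor:
    """Hypothetical stand-in for a single-feature jagged tensor."""
    values: List[int]
    lengths: List[int]


def build_key_metadata(
    jt_dict: Dict[str, ToyJaggedTensor],
) -> Tuple[List[str], List[int]]:
    """Collect keys and per-key value counts without reading any lengths data."""
    keys: List[str] = []
    length_per_key: List[int] = []
    for key, jt in jt_dict.items():
        keys.append(key)
        # Only the size of the values container is used here. For a tensor
        # this would be shape metadata, not the tensor's contents.
        length_per_key.append(len(jt.values))
    return keys, length_per_key


jt_dict = {
    "user_clicked_item_ids": ToyJaggedTensor(values=[10, 11, 12], lengths=[2, 1]),
    "user_viewed_item_ids": ToyJaggedTensor(values=[20, 21, 22, 23], lengths=[1, 3]),
}
keys, length_per_key = build_key_metadata(jt_dict)
print(keys, length_per_key)
```

The loop with list appends is the kind of Python-side bookkeeping described above as not traceable into the computation graph.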
Conclusions
To use KeyedJaggedTensor efficiently in a GPU system, it should be constructed from CPU JaggedTensors and then moved to GPU in eager mode. This usually means KeyedJaggedTensor should be constructed in the data preprocessing stage. In a model running on GPU, KeyedJaggedTensor should only be used as model input. One should avoid constructing KeyedJaggedTensor from GPU JaggedTensors, especially inside the model. In this way, the GPU operations in the model can run asynchronously without GPU-CPU synchronization, which results in the best performance.