Overview
PyTorch Lightning versions 2.4.0 and earlier do not verify that model files are safe before loading them. Users of PyTorch Lightning should use caution when loading models from unknown or unmanaged sources.
Description
PyTorch Lightning, a high-level framework built on top of PyTorch, is designed to streamline deep learning model training, scaling, and deployment. PyTorch Lightning is widely used in AI research and production environments, often integrating with various cloud and distributed computing platforms to manage large-scale machine learning workloads.
PyTorch Lightning contains multiple vulnerabilities related to the deserialization of untrusted data (CWE-502). These vulnerabilities arise from the unsafe use of torch.load(), which is used to deserialize model checkpoints, configurations, and sometimes metadata. Although torch.load() provides an optional weights_only=True parameter that mitigates the risk of executing arbitrary code during loading, PyTorch Lightning does not require or enforce this safeguard as a principal security requirement for the product.
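To illustrate why deserializing untrusted data is risky, the following minimal sketch uses only Python's standard pickle module, which torch.load() relies on by default. It is not code from PyTorch Lightning or PyTorch. The names MaliciousPayload, AllowListUnpickler, and safe_loads are hypothetical, and the payload only evaluates a harmless arithmetic expression. The allow-list unpickler is loosely analogous in spirit to what weights_only=True aims for and should not be read as its actual implementation.

```python
import io
import pickle


class MaliciousPayload:
    """Hypothetical object whose pickled form invokes a callable on load."""

    def __reduce__(self):
        # On deserialization, pickle calls eval("6 * 7").
        # A real attacker could substitute any importable callable here.
        return (eval, ("6 * 7",))


class AllowListUnpickler(pickle.Unpickler):
    """Hypothetical unpickler that rejects every global not explicitly allowed."""

    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")


def safe_loads(data):
    """Deserialize bytes using the allow-list unpickler."""
    return AllowListUnpickler(io.BytesIO(data)).load()


blob = pickle.dumps(MaliciousPayload())

# Unrestricted load: the embedded callable runs and its result is returned.
result = pickle.loads(blob)

# Restricted load: the reference to builtins.eval is refused before it can run.
try:
    safe_loads(blob)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# Plain containers and numbers need no globals, so they still load.
plain = safe_loads(pickle.dumps({"weight": [1.0, 2.0]}))
```

In this sketch the unrestricted load returns the value computed by the payload's callable, while the allow-list loader raises an error for the same bytes. An allow-list narrows the attack surface but does not make untrusted files trustworthy, so verifying the origin of a checkpoint remains advisable.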
Kasimir Schulz of HiddenLayer identified and reported the following five vulnerabilities:
The DeepSpeed integration in PyTorch Lightning loads optimizer states and model checkpoints without enforcing safe deserialization practices. It does not validate the integrity or origin of serialized data before passing it to torch.load(), allowing deserialization of arbitrary objects.
The PickleSerializer class directly utilizes Python’s pickle module to handle data serialization and deserialization. Since pickle inherently allows execution of embedded code during deserialization, any untrusted or manipulated input processed by this class can introduce security risks.
The _load_distributed_checkpoint component is responsible for handling distributed training checkpoints. It processes model state data across multiple nodes, but it does not include safeguards to verify or restrict the content being deserialized.
The _lazy_lo ..