cryovit.types
Custom types and dataclasses for CryoViT models.
Classes
- This class represents the model result from a batch of tomograms, organized per tomogram.
- FileData: This class represents the file data for a single tomogram.
- Enum of all valid model types.
- Enum of all valid CryoET samples.
- class FileData(tomo_path: Path, label_path: Path | None = None, labels: list[str] | None = None, sample: str | None = None)
Bases: object
This class represents the file data for a single tomogram.
- tomo_path
A path to the raw tomogram data.
- Type:
pathlib.Path
- label_path
A path to the segmentation labels. None if not available.
- Type:
pathlib.Path | None
- labels
A list of strings representing the label names. None if not available.
- Type:
list[str] | None
- sample
A string representing the sample. None if not available.
- Type:
str | None
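The attributes above map naturally onto a plain Python dataclass. The sketch below uses a hypothetical stand-in, `FileDataSketch`, that mirrors the documented `FileData` signature instead of importing `cryovit.types`. The file paths, label names and sample string are made-up placeholders, and the `has_labels` helper is an illustration only, not part of CryoViT's API.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from pathlib import Path


@dataclass
class FileDataSketch:
    """Hypothetical stand-in mirroring the documented FileData signature."""

    tomo_path: Path
    label_path: Path | None = None
    labels: list[str] | None = None
    sample: str | None = None

    def has_labels(self) -> bool:
        # Hypothetical helper: True only if both a label file and label names are set.
        return self.label_path is not None and self.labels is not None


# A tomogram without annotations only needs its raw data path.
unlabeled = FileDataSketch(tomo_path=Path("tomograms/tomo_001"))

# An annotated tomogram also records where its labels live and what they are called.
labeled = FileDataSketch(
    tomo_path=Path("tomograms/tomo_002"),
    label_path=Path("labels/tomo_002"),
    labels=["label_a", "label_b"],
    sample="SAMPLE_A",
)

print([f.name for f in fields(labeled)])  # ['tomo_path', 'label_path', 'labels', 'sample']
print(unlabeled.has_labels(), labeled.has_labels())  # False True
```

Only `tomo_path` is required; the three optional fields default to None, matching the "None if not available" notes above.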
- class TomogramData(sample: str, tomo_name: str, split_id: int | None, data: torch.FloatTensor, label: torch.BoolTensor, aux_data: dict[str, Any] | None = None, *, batch_size, device=None, names=None)
Bases: object
- property device: device
Retrieves the device type of tensor class.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- classmethod fields()
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) → Any
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDictBase) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
safe (bool) – Whether to raise an error if the tensordict is not a TensorDictBase instance
- get(key: NestedKey, *args, **kwargs)
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
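The tuple-key behaviour of get() can be pictured as a chain of getattr calls. This is a minimal sketch assuming plain attribute access on nested objects; `get_nested` is a hypothetical helper, not tensordict's implementation, and it returns `default` instead of raising when a link in the chain is missing.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def get_nested(obj: Any, key: str | tuple[str, ...], default: Any = None) -> Any:
    """Hypothetical helper: resolve a str or tuple-of-str key via chained getattr."""
    keys = (key,) if isinstance(key, str) else key
    for name in keys:
        # Return the default as soon as one link of the chain is missing.
        if not hasattr(obj, name):
            return default
        obj = getattr(obj, name)
    return obj


# Hypothetical nested record standing in for a tensorclass with nested entries.
record = SimpleNamespace(
    sample="SAMPLE_A",
    aux=SimpleNamespace(stats=SimpleNamespace(mean=0.5)),
)

print(get_nested(record, "sample"))                  # SAMPLE_A
print(get_nested(record, ("aux", "stats", "mean")))  # 0.5, same as record.aux.stats.mean
print(get_nested(record, ("aux", "missing"), -1))    # -1
```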
- classmethod load(prefix: str | Path, *args, **kwargs) → Any
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)
Loads a tensordict from disk into the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | Path, device: device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) → Any
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)
Attempts to load a state_dict in-place into the destination tensorclass.
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) → Any
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) → Any
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Creates a contentless memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will attempt to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination.
- Returns:
self
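For a tuple key, set() is described as chained getattr calls followed by a final setattr. A minimal sketch of that idea on plain Python objects; `set_nested` is a hypothetical helper, not tensordict's implementation, and it ignores the inplace and non_blocking options.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def set_nested(obj: Any, key: str | tuple[str, ...], value: Any) -> Any:
    """Hypothetical helper: chained getattr down to the parent, then one setattr."""
    keys = (key,) if isinstance(key, str) else key
    parent = obj
    for name in keys[:-1]:
        parent = getattr(parent, name)  # walk down to the parent of the leaf
    setattr(parent, keys[-1], value)    # write (or overwrite) the leaf attribute
    return obj  # return the root object, mirroring set() returning self


record = SimpleNamespace(aux=SimpleNamespace(stats=SimpleNamespace(mean=0.0)))
set_nested(record, ("aux", "stats", "mean"), 0.5)
set_nested(record, "sample", "SAMPLE_A")
print(record.aux.stats.mean, record.sample)  # 0.5 SAMPLE_A
```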
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) → dict[str, Any]
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: bool | None = None) → TensorDict
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
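The storage-sharing note above can be seen with plain PyTorch tensors via torch.unbind, which returns views rather than copies. This sketch uses an ordinary tensor as a stand-in for a tensorclass and assumes the tensorclass method follows the same view semantics the note describes.

```python
import torch

# A stand-in batch of 2 "tomograms", each 3x4 (a plain tensor, not a tensorclass).
batch = torch.zeros(2, 3, 4)

# torch.unbind returns one view per index along dim 0; no data is copied.
first, second = torch.unbind(batch, dim=0)
print(first.shape)  # torch.Size([3, 4])

# Writing through a view changes the original batch, i.e. storage is shared.
first[0, 0] = 7.0
print(batch[0, 0, 0].item())  # 7.0

# Use .clone() when an independent copy is needed.
independent = second.clone()
independent.fill_(1.0)
print(batch[1].sum().item())  # 0.0
```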
- class BatchedTomogramMetadata(samples: list[str], tomo_names: list[str], unique_id: torch.LongTensor, split_id: list[torch.IntTensor] | None, *, batch_size, device=None, names=None)
Bases: object
- property device: device
Retrieves the device type of tensor class.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- classmethod fields()
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) → Any
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDictBase) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
safe (bool) – Whether to raise an error if the tensordict is not a TensorDictBase instance
- get(key: NestedKey, *args, **kwargs)
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
- classmethod load(prefix: str | Path, *args, **kwargs) → Any
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)
Loads a tensordict from disk into the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | Path, device: device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) → Any
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)
Attempts to load a state_dict in-place into the destination tensorclass.
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) → Any
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) → Any
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Creates a contentless memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will attempt to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination.
- Returns:
self
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) → dict[str, Any]
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: bool | None = None) → TensorDict
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
- class BatchedTomogramData(tomo_batch: torch.Tensor, tomo_sizes: torch.Tensor, labels: torch.Tensor, metadata: cryovit.types.BatchedTomogramMetadata, min_slices: int, aux_data: dict[str, list[Any]] | None = None, *, batch_size, device=None, names=None)
Bases: object
- property num_tomos: int
Returns the number of tomograms in the batch.
- property num_slices: int
Returns the maximum number of slices in the batch.
- index_to_flat_batch(idx: int) → Tensor
Returns a [BxD] tensor containing the indices corresponding to a certain slice in a flat batch tensor.
- property flat_tomo_batch: Tensor
Returns a [[BxD]xCxHxW] tensor from a [BxDxCxHxW] tensor (C is optional).
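flat_tomo_batch describes a reshape that merges the batch and depth dimensions. The document does not say how tomograms of different depths end up in one [BxDxCxHxW] batch; since num_slices is the maximum number of slices and tomo_sizes is stored separately, this sketch assumes zero-padding to the deepest tomogram. `pad_and_stack` and `flatten_batch` are hypothetical helpers, not CryoViT's code, and the `b * D + d` flat-index rule shown is a property of row-major reshaping, not a statement about what index_to_flat_batch returns.

```python
import torch


def pad_and_stack(tomos: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: zero-pad [D_i, C, H, W] tomograms to the max depth."""
    sizes = torch.tensor([t.shape[0] for t in tomos])
    max_d = int(sizes.max())
    padded = [
        torch.cat([t, t.new_zeros((max_d - t.shape[0], *t.shape[1:]))], dim=0)
        for t in tomos
    ]
    return torch.stack(padded), sizes  # [B, D, C, H, W] and per-tomogram depths [B]


def flatten_batch(tomo_batch: torch.Tensor) -> torch.Tensor:
    """Merge batch and depth dims: [B, D, ...] -> [B*D, ...] (C may be absent)."""
    b, d = tomo_batch.shape[:2]
    return tomo_batch.reshape(b * d, *tomo_batch.shape[2:])


tomos = [torch.randn(3, 1, 8, 8), torch.randn(5, 1, 8, 8)]
tomo_batch, tomo_sizes = pad_and_stack(tomos)
flat = flatten_batch(tomo_batch)
print(tomo_batch.shape)  # torch.Size([2, 5, 1, 8, 8])
print(flat.shape)        # torch.Size([10, 1, 8, 8])

# With row-major flattening, slice d of tomogram b lands at flat index b * D + d.
D = tomo_batch.shape[1]
assert torch.equal(flat[1 * D + 2], tomo_batch[1, 2])
```

Padded slices (here slices 3 and 4 of the first tomogram) are still present in the flat tensor, so any per-slice loss or metric would need tomo_sizes to mask them out.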
- property device: device
Retrieves the device type of tensor class.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- classmethod fields()
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) → Any
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDictBase) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
safe (bool) – Whether to raise an error if the tensordict is not a TensorDictBase instance
- get(key: NestedKey, *args, **kwargs)
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
- classmethod load(prefix: str | Path, *args, **kwargs) → Any
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)
Loads a tensordict from disk into the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | Path, device: device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) → Any
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)
Attempts to load a state_dict in-place into the destination tensorclass.
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) → Any
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and write operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Any
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any
Creates a contentless Memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If a tuple of str, it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass.
inplace (bool, optional) – if True, set will attempt to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination.
- Returns:
self
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: bool | None = None) TensorDict
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
- class BatchedModelResult(num_tomos: int, samples: list[str], tomo_names: list[str], split_id: list[int] | None, data: list[ndarray[tuple[int, ...], dtype[float32]]], label: list[ndarray[tuple[int, ...], dtype[uint8]]], preds: list[ndarray[tuple[int, ...], dtype[float32]]], losses: dict[str, float], metrics: dict[str, float], aux_data: dict[str, list[Any]] | None = None)[source]
Bases:
object
This class represents the model result from a batch of tomograms, organized per tomogram.
- num_tomos
The number of tomograms in the batch.
- Type:
int
- samples
The sample for each tomogram in the batch.
- Type:
list[str]
- tomo_names
The file name for each tomogram in the batch.
- Type:
list[str]
- split_id
The optional split id for each tomogram in the batch.
- Type:
list[int] | None
- data
The raw tomogram data for each tomogram in the batch.
- Type:
list[numpy.ndarray[tuple[int, …], numpy.dtype[numpy.float32]]]
- label
The true segmentation labels for each tomogram in the batch.
- Type:
list[numpy.ndarray[tuple[int, …], numpy.dtype[numpy.uint8]]]
- preds
The model predictions for each tomogram in the batch.
- Type:
list[numpy.ndarray[tuple[int, …], numpy.dtype[numpy.float32]]]
- losses
A dictionary of losses for each tomogram in the batch.
- Type:
dict[str, float]
- metrics
A dictionary of metrics for each tomogram in the batch.
- Type:
dict[str, float]
- aux_data
An optional dictionary containing auxiliary data for each tomogram in the batch.
- Type:
dict[str, list[Any]] | None
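Because the per-tomogram fields are parallel lists, index i across samples, tomo_names, data, label and preds describes a single tomogram. A sketch of walking a result that way, assuming NumPy is installed; `iter_thresholded` and its 0.5 threshold are hypothetical illustrations rather than CryoViT post-processing, and a `SimpleNamespace` stands in for a real BatchedModelResult with only the fields the helper reads.

```python
from types import SimpleNamespace

import numpy as np


def iter_thresholded(result, threshold=0.5):
    """Yield (sample, tomo_name, mask) per tomogram of a BatchedModelResult-like object.

    Hypothetical helper; the threshold is an illustrative assumption.
    """
    # The per-tomogram lists are aligned by index, so position i is one tomogram.
    for i in range(result.num_tomos):
        mask = result.preds[i] > threshold  # boolean foreground mask
        yield result.samples[i], result.tomo_names[i], mask


# Hypothetical two-tomogram result with only the fields read above.
result = SimpleNamespace(
    num_tomos=2,
    samples=["sample_a", "sample_b"],
    tomo_names=["tomo_001", "tomo_002"],
    preds=[
        np.array([[0.9, 0.1], [0.6, 0.2]], dtype=np.float32),
        np.zeros((2, 2), dtype=np.float32),
    ],
)

for sample, name, mask in iter_thresholded(result):
    print(sample, name, int(mask.sum()))
```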