# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
from typing import Any
from fvdb_reality_capture.sfm_scene import SfmScene
from .base_transform import REGISTERED_TRANSFORMS, BaseTransform, transform
@transform
class Compose(BaseTransform):
    """
    A :class:`~base_transform.BaseTransform` that composes multiple transforms together in sequence.

    This is useful for encoding a sequence of transforms into a single object.
    The transforms are applied in the order they are provided, allowing for complex data processing pipelines.

    Example usage:

    .. code-block:: python

        from fvdb_reality_capture import transforms
        from fvdb_reality_capture.sfm_scene import SfmScene

        scene_transform = transforms.Compose(
            transforms.NormalizeScene("pca"),
            transforms.DownsampleImages(4),
        )

        input_scene: SfmScene = ...  # Load or create an SfmScene
        transformed_scene: SfmScene = scene_transform(input_scene)
    """

    version = "1.0.0"

    def __init__(self, *transforms: BaseTransform):
        """
        Initialize the Compose transform with a sequence of transforms.

        Args:
            *transforms (BaseTransform): The :class:`~base_transform.BaseTransform` instances to compose,
                in the order in which they should be applied. They are stored as a tuple in
                ``self.transforms``.

        Raises:
            TypeError: If any argument is not a :class:`~base_transform.BaseTransform` instance.
        """
        super().__init__()
        # The loop variable is deliberately not named ``transform`` so it does not shadow the
        # ``transform`` registration decorator imported at module level.
        for sub_transform in transforms:
            if not isinstance(sub_transform, BaseTransform):
                raise TypeError(f"Expected a BaseTransform instance, got {type(sub_transform)} instead.")
        self.transforms = transforms

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """
        Apply the composed transforms sequentially to the input scene.

        Each transform receives the output of the previous one, starting with ``input_scene``.

        Args:
            input_scene (SfmScene): The input :class:`~fvdb_reality_capture.sfm_scene.SfmScene` to transform.

        Returns:
            output_scene (SfmScene): The :class:`~fvdb_reality_capture.sfm_scene.SfmScene` produced by the
                last composed transform. If no transforms were composed, ``input_scene`` itself is
                returned unchanged.
        """
        output_scene = input_scene
        for sub_transform in self.transforms:
            output_scene = sub_transform(output_scene)
        return output_scene

    def state_dict(self) -> dict[str, Any]:
        """
        Return the state of the :class:`Compose` transform for serialization.

        You can use this state dictionary to recreate the transform using :meth:`from_state_dict`.

        Returns:
            state_dict (dict[str, Any]): A dictionary with keys ``"name"``, ``"version"`` and ``"transforms"``.
                ``"transforms"`` is a list with one ``{"name": ..., "state": ...}`` entry per composed
                transform, in application order.
        """
        return {
            "name": self.name(),
            "version": self.version,
            "transforms": [
                {"name": sub_transform.name(), "state": sub_transform.state_dict()}
                for sub_transform in self.transforms
            ],
        }

    @staticmethod
    def name() -> str:
        """
        Return the name of the :class:`Compose` transform. *i.e.* ``"Compose"``.

        Returns:
            str: The name of the :class:`Compose` transform. *i.e.* ``"Compose"``.
        """
        return "Compose"

    @staticmethod
    def from_state_dict(state_dict: dict[str, Any]) -> "Compose":
        """
        Create a :class:`Compose` transform from a state dictionary generated with :meth:`state_dict`.

        Each entry of ``state_dict["transforms"]`` is resolved through ``REGISTERED_TRANSFORMS`` by its
        ``"name"`` and rebuilt with that class's ``from_state_dict``.

        Args:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.

        Returns:
            transform (:class:`Compose`): An instance of the :class:`Compose` transform loaded from the state dictionary.

        Raises:
            ValueError: If ``state_dict`` is missing its ``"name"`` or has a name other than ``"Compose"``,
                is missing its ``"transforms"`` key, has an entry without ``"name"``/``"state"``, or
                references a transform that is not registered.
            TypeError: If ``state_dict["transforms"]`` is not a list, or any of its entries is not a dict.
        """
        # Use .get so a missing "name" is reported as ValueError, like every other malformed state,
        # instead of escaping as a bare KeyError.
        # NOTE(review): the serialized "version" is not validated here.
        name = state_dict.get("name")
        if name != Compose.name():
            raise ValueError(f"Expected state_dict with name 'Compose', got {name} instead.")
        if "transforms" not in state_dict:
            raise ValueError("State dictionary must contain 'transforms' key.")
        transform_states = state_dict["transforms"]
        if not isinstance(transform_states, list):
            raise TypeError(f"Expected 'transforms' to be a list, got {type(transform_states)} instead.")

        sub_transforms = []
        for transform_state in transform_states:
            if not isinstance(transform_state, dict):
                raise TypeError(f"Expected each transform state to be a dict, got {type(transform_state)} instead.")
            if "name" not in transform_state:
                raise ValueError("Each transform state must contain a 'name' key.")
            if "state" not in transform_state:
                raise ValueError("Each transform state must contain a 'state' key.")
            transform_name = transform_state["name"]
            # Transform classes add themselves to REGISTERED_TRANSFORMS via the `transform` decorator,
            # so a class that was never imported cannot be found here.
            transform_cls = REGISTERED_TRANSFORMS.get(transform_name)
            if transform_cls is None:
                raise ValueError(
                    f"Transform '{transform_name}' is not registered. Transform classes must be registered "
                    "with the `transform` decorator which will be called when the transform is defined. "
                    "Ensure the transform class uses the `transform` decorator and was imported before calling from_state_dict."
                )
            sub_transforms.append(transform_cls.from_state_dict(transform_state["state"]))
        return Compose(*sub_transforms)