# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
import os
from typing import Any, Sequence
import setuptools
from torch.utils import cpp_extension
import fvdb
[docs]
def fvdbCudaExtension(name: str, sources: Sequence[str], *args: Any, **kwargs: Any) -> setuptools.Extension:
"""
Utility function for creating pytorch extensions that depend on fvdb. You then have access to all fVDB's internal
headers to program with. Example usage:
.. code-block:: python
from fvdb.utils import FVDBExtension
ext = FVDBExtension(
name='my_extension',
sources=['my_extension.cpp'],
extra_compile_args={'cxx': ['-std=c++17']},
libraries=['mylib'],
)
Args:
name (str): The name of the extension.
sources (Sequence[str]): The list of source files.
args (list[Any]): Other arguments to pass to :func:`torch.utils.cpp_extension.CppExtension`.
kwargs (dict): Other keyword arguments to pass to :func:`torch.utils.cpp_extension.CppExtension`.
Returns:
cpp_extension (setuptools.Extension) A :class:`setuptools.Extension` object which can be used
to build a PyTorch C++ extension that depends on fVDB.
"""
libraries = kwargs.get("libraries", [])
libraries.append("fvdb")
kwargs["libraries"] = libraries
library_dirs = kwargs.get("library_dirs", [])
library_dirs.append(os.path.dirname(fvdb.__file__))
kwargs["library_dirs"] = library_dirs
include_dirs = kwargs.get("include_dirs", [])
include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include"))
# We also need to add this because fvdb internally will refer to their headers without the fvdb/ prefix.
include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include/fvdb"))
kwargs["include_dirs"] = include_dirs
extra_link_args = kwargs.get("extra_link_args", [])
extra_link_args.append(f"-Wl,-rpath={os.path.dirname(fvdb.__file__)}")
kwargs["extra_link_args"] = extra_link_args
extra_compile_args = kwargs.get("extra_compile_args", {})
extra_compile_args["nvcc"] = extra_compile_args.get("nvcc", [])
if "--extended-lambda" not in extra_compile_args["nvcc"]:
extra_compile_args["nvcc"].append("--extended-lambda")
kwargs["extra_compile_args"] = extra_compile_args
return cpp_extension.CUDAExtension(name, sources, *args, **kwargs)