JaggedTensor
- class fvdb.JaggedTensor(tensors: Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]] | None = None, *, impl: JaggedTensor | None = None)
A jagged (ragged) tensor data structure with support for efficient operations.
JaggedTensor represents sequences of tensors with varying lengths, stored efficiently in a flat contiguous format with associated index/offset structures. This is useful for batch processing of variable-length sequences on the GPU while maintaining memory efficiency and enabling vectorized operations.
A JaggedTensor can represent:
1. A sequence of tensors with varying shapes along the first dimension. These are usually written as [tensor_1, tensor_2, ..., tensor_N], where each tensor_i can have a different shape along the first dimension.
2. Nested sequences (lists of lists) with varying lengths at multiple levels. These are usually written as [[tensor_11, tensor_12, ...], [tensor_21, tensor_22, ...], ...], where both the outer and inner sequences can have varying lengths, and each tensor_ij can have a different shape along the first dimension.
The JaggedTensor data structure consists of the flattened data tensor jdata together with the tensors that describe its jagged layout: joffsets, jidx, and jlidx (each documented below).
JaggedTensor integrates with PyTorch through __torch_function__, allowing many torch operations to work directly on jagged tensors while preserving the jagged structure. Operations that preserve the leading (flattened) dimension work seamlessly, while shape-changing operations require the specialized j* methods (such as jreshape(), jflatten(), and jsum()).
Example usage:
import torch
from fvdb import JaggedTensor

# Create a JaggedTensor from a list of tensors
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)])

# Perform element-wise operations
jt2 = jt + 1.0
jt3 = torch.relu(jt2)

# Access jagged data and structure
data = jt3.jdata
offsets = jt3.joffsets

# Get the first tensor in the jagged sequence
first_tensor = jt3[0]

# Get the last tensor in the jagged sequence
last_tensor = jt3[-1]
Note
The JaggedTensor should be constructed using the explicit classmethods:
  - from_tensor() for a single tensor
  - from_list_of_tensors() for a list of tensors
  - from_list_of_lists_of_tensors() for nested lists of tensors
  - from_data_and_indices() for a pre-computed flat format with per-element indices
  - from_data_and_offsets() for a pre-computed flat format with offsets
  - from_data_indices_and_list_ids() and from_data_offsets_and_list_ids() for a pre-computed flat format with list IDs
  - from_empty() for uninitialized data
  - from_ones() for data filled with ones
  - from_rand() for data filled with random values from a uniform distribution on [0, 1)
  - from_randn() for data filled with random values from a standard normal distribution (mean 0, variance 1)
  - from_zeros() for data filled with zeros
  - from_cat() for concatenation of jagged tensors
- abs() → JaggedTensor
Compute the absolute value element-wise.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with absolute values.
- abs_() → JaggedTensor
Compute the absolute value element-wise in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- ceil() → JaggedTensor
Round elements up to the nearest integer.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with ceiling applied.
- ceil_() → JaggedTensor
Round elements up to the nearest integer in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- clone() → JaggedTensor
Create a deep copy of the JaggedTensor.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with copied data and structure.
- cpu() → JaggedTensor
Move the JaggedTensor to CPU memory.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor on the CPU device.
- cuda() → JaggedTensor
Move the JaggedTensor to CUDA (GPU) memory.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor on a CUDA device.
- detach() → JaggedTensor
Detach the JaggedTensor from the autograd graph.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor detached from the computation graph.
- property device: device
Device where this JaggedTensor is stored.
- Returns:
torch.device – The device of this JaggedTensor.
- double() → JaggedTensor
Convert elements to double (float64) dtype.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with double precision.
- property dtype: dtype
Data type of the elements in this JaggedTensor.
- Returns:
torch.dtype – The data type of this JaggedTensor.
- property edim: int
Dimensionality of the element (regular) structure.
For example, if each tensor in the jagged sequence has shape (?, 4, 5), then edim will be 2, since there are two regular dimensions per element.
- Returns:
int – The dimensionality of the element structure.
- property eshape: list[int]
Shape of the element dimensions.
For example, if each tensor in the jagged sequence has shape (?, 4, 5), then eshape will be [4, 5].
- Returns:
list[int] – The shape of the element dimensions.
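The relationship between eshape, edim, and the rshape property (the full shape of jdata) can be checked with a small plain-Python model. It works on shape tuples only, does not use fvdb, and assumes, as from_list_of_tensors() requires, that all tensors share their trailing dimensions; element_metadata is a hypothetical helper, not part of the fvdb API.

```python
def element_metadata(shapes):
    """Return (eshape, edim, rshape) for per-tensor shapes (L_i, *eshape).

    Hypothetical helper: a plain-Python model, not fvdb code.
    """
    eshape = list(shapes[0][1:])
    for s in shapes:
        # Every tensor must share the same trailing (element) dimensions.
        if list(s[1:]) != eshape:
            raise ValueError(f"element shape {tuple(s[1:])} != {tuple(eshape)}")
    edim = len(eshape)
    total = sum(s[0] for s in shapes)  # size of the flattened leading dimension
    rshape = (total, *eshape)
    return eshape, edim, rshape


eshape, edim, rshape = element_metadata([(2, 4, 5), (3, 4, 5), (1, 4, 5)])
print(eshape, edim, rshape)  # [4, 5] 2 (6, 4, 5)
```

The leading entry of rshape is the sum of the per-tensor lengths, which is why jdata for tensors of shapes (2, 4), (3, 4), and (1, 4) has shape (6, 4).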
- float() → JaggedTensor
Convert elements to float (float32) dtype.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with float32 precision.
- floor() → JaggedTensor
Round elements down to the nearest integer.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with floor applied.
- floor_() → JaggedTensor
Round elements down to the nearest integer in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- classmethod from_cat(things_to_cat: Sequence[JaggedTensor], dim: int | None = None) → JaggedTensor
Create a JaggedTensor by concatenating a sequence of JaggedTensors.
- Parameters:
things_to_cat – The sequence of JaggedTensors to concatenate.
dim – The dimension to concatenate along, optional.
- classmethod from_data_and_indices(data: Tensor, indices: Tensor, num_tensors: int) → JaggedTensor
Create a JaggedTensor from flattened data and per-element indices.
Example
data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])
jt = JaggedTensor.from_data_and_indices(data, indices, num_tensors=3)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
indices (torch.Tensor) – Index tensor mapping each element to its parent tensor. Shape: (total_elements,). Values in the range [0, num_tensors).
num_tensors (int) – Total number of tensors in the sequence.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and indices.
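from_data_and_indices() and from_data_and_offsets() describe the same layout in two encodings. The plain-Python sketch below (not fvdb code; indices_to_offsets is a hypothetical helper) counts the elements that belong to each tensor and accumulates the counts into boundary offsets. It assumes the elements of each tensor are stored contiguously and in tensor order, as in the example above.

```python
def indices_to_offsets(indices, num_tensors):
    """Hypothetical helper: per-element tensor indices -> boundary offsets.

    Assumes elements of the same tensor are contiguous and in tensor order.
    """
    counts = [0] * num_tensors
    for i in indices:
        if not 0 <= i < num_tensors:
            raise ValueError(f"index {i} outside [0, {num_tensors})")
        counts[i] += 1
    offsets = [0]
    for c in counts:
        # Running total: the end of tensor i is the start of tensor i + 1.
        offsets.append(offsets[-1] + c)
    return offsets


print(indices_to_offsets([0, 0, 1, 1, 1, 2], num_tensors=3))  # [0, 2, 5, 6]
```

A tensor with no elements simply gets a count of zero, so two neighbouring offsets are equal in this model.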
- classmethod from_data_and_offsets(data: Tensor, offsets: Tensor) → JaggedTensor
Create a JaggedTensor from flattened data and an offset array.
Offsets define boundaries between tensors in the flattened data array. Tensor i contains elements data[offsets[i]:offsets[i+1]].
Example:
data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]
jt = JaggedTensor.from_data_and_offsets(data, offsets)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
offsets (torch.Tensor) – Offset tensor marking tensor boundaries. Shape: (num_tensors + 1,). Must be monotonically increasing.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and offsets.
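The slicing rule data[offsets[i]:offsets[i+1]] can be modelled directly in plain Python. This is an illustrative sketch using a hypothetical split_by_offsets helper that operates on lists rather than tensors; it is not fvdb's implementation.

```python
def split_by_offsets(data, offsets):
    """Hypothetical helper: return the slices data[offsets[i]:offsets[i + 1]]."""
    # There is one more offset than there are tensors.
    return [data[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]


data = [1, 2, 3, 4, 5, 6]
offsets = [0, 2, 5, 6]  # 3 tensors: [0:2], [2:5], [5:6]
print(split_by_offsets(data, offsets))  # [[1, 2], [3, 4, 5], [6]]
```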
- classmethod from_data_indices_and_list_ids(data: Tensor, indices: Tensor, list_ids: Tensor, num_tensors: int) → JaggedTensor
Create a JaggedTensor from data, per-element indices, and list IDs.
This function validates that data, indices, list_ids, and num_tensors are compatible. The offsets are computed internally from the indices.
Example (ldim == 1, list of tensors):
data = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 0, 1, 1, 1])  # Elements 0-1 in tensor 0, elements 2-4 in tensor 1
list_ids = torch.empty((0, 1), dtype=torch.int)  # Empty for ldim == 1
jt = JaggedTensor.from_data_indices_and_list_ids(data, indices, list_ids, 2)
# Result: [[1, 2], [3, 4, 5]]
Example (ldim == 2, list of lists of tensors):
data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])  # 3 tensors
list_ids = torch.tensor([[0, 0], [0, 1], [1, 0]])  # Outer/inner indices
jt = JaggedTensor.from_data_indices_and_list_ids(data, indices, list_ids, 3)
# Result: [[[1, 2], [3, 4, 5]], [[6]]]
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
indices (torch.Tensor) – Index tensor mapping each element to its tensor index (0 to num_tensors - 1). Shape: (total_elements,), or empty if num_tensors == 1.
list_ids (torch.Tensor) – Tensor defining the hierarchical position of each tensor. For ldim == 1: shape (num_tensors, 1) or an empty tensor with shape (0, 1); an empty tensor assumes a single, naturally ordered list of tensors. For ldim == 2: shape (num_tensors, 2), where each row is (outer_idx, inner_idx).
num_tensors (int) – Total number of tensors.
- Returns:
jagged_tensor (JaggedTensor) – A JaggedTensor defined by the data, indices, and list IDs.
- classmethod from_data_offsets_and_list_ids(data: Tensor, offsets: Tensor, list_ids: Tensor) → JaggedTensor
Create a JaggedTensor from data, offsets, and list IDs.
This function validates that data, offsets, and list_ids are compatible. The per-element indices are computed internally from the offsets.
Example (ldim == 1, list of tensors):
data = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 2, 5])  # 2 tensors: data[0:2], data[2:5]
list_ids = torch.empty((0, 1), dtype=torch.int)  # Empty for ldim == 1
jt = JaggedTensor.from_data_offsets_and_list_ids(data, offsets, list_ids)
# Result: [[1, 2], [3, 4, 5]]
Example (ldim == 2, list of lists of tensors):
data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]
# list_ids maps: tensor 0 -> list[0][0], tensor 1 -> list[0][1], tensor 2 -> list[1][0]
list_ids = torch.tensor([[0, 0], [0, 1], [1, 0]])
jt = JaggedTensor.from_data_offsets_and_list_ids(data, offsets, list_ids)
# Result: [[[1, 2], [3, 4, 5]], [[6]]]
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
offsets (torch.Tensor) – Offset tensor marking tensor boundaries. offsets[i] is the start index of tensor i, and offsets[i+1] - offsets[i] is its length. Shape: (num_tensors + 1,).
list_ids (torch.Tensor) – Tensor defining the hierarchical position of each tensor. For ldim == 1: shape (num_tensors, 1) or an empty tensor with shape (0, 1); an empty tensor assumes a single, naturally ordered list of tensors. For ldim == 2: shape (num_tensors, 2), where each row is (outer_idx, inner_idx).
- Returns:
jagged_tensor (JaggedTensor) – A JaggedTensor defined by the data, offsets, and list IDs.
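For ldim == 2, each row of list_ids places one flat tensor slice at an (outer_idx, inner_idx) position. The plain-Python model below (hypothetical nest_by_list_ids helper, lists instead of tensors) rebuilds the nested result from data, offsets, and list_ids. It does not check that the outer and inner indices are dense; fvdb's exact validation rules are not described on this page.

```python
def nest_by_list_ids(data, offsets, list_ids):
    """Hypothetical helper: nest flat tensor slices by (outer, inner) list IDs."""
    slices = [data[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]
    if len(list_ids) != len(slices):
        raise ValueError("list_ids needs exactly one row per tensor")
    nested = {}
    for (outer, inner), tensor in zip(list_ids, slices):
        nested.setdefault(outer, {})[inner] = tensor
    # Emit outer lists in index order, and inner tensors in index order.
    return [[nested[o][i] for i in sorted(nested[o])] for o in sorted(nested)]


data = [1, 2, 3, 4, 5, 6]
offsets = [0, 2, 5, 6]
list_ids = [[0, 0], [0, 1], [1, 0]]
print(nest_by_list_ids(data, offsets, list_ids))  # [[[1, 2], [3, 4, 5]], [[6]]]
```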
- classmethod from_empty(lsizes: Sequence[int] | Sequence[Sequence[int]], rsizes: Sequence[int] | None = None, *, device: device | str | None = None, dtype: dtype | None = None, requires_grad: bool = False, pin_memory: bool = False) → JaggedTensor
Create a JaggedTensor with uninitialized data.
Similar to torch.empty(), creates a JaggedTensor with allocated but uninitialized memory, which is faster than initializing values when they will be immediately overwritten.
Example:
jt = JaggedTensor.from_empty([2, 3, 4], rsizes=[5])
# jt holds tensors of shapes (2, 5), (3, 5), and (4, 5) with uninitialized values.
jt = JaggedTensor.from_empty([[2, 3], [4]], rsizes=[5, 6])
# jt holds tensors of shapes (2, 5, 6), (3, 5, 6), and (4, 5, 6) with uninitialized values.
- Parameters:
lsizes (Sequence[int] | Sequence[Sequence[int]]) – Sizes for the jagged dimensions. Can be a sequence of integers for a simple jagged structure, or nested sequences for a multi-level jagged structure.
rsizes (Sequence[int] | None) – Sizes for the regular (trailing) dimensions. Defaults to None, i.e. scalar elements.
device (torch.device | str | None) – Device to create the tensor on. Defaults to None, i.e. "cpu".
dtype (torch.dtype | None) – Data type for the tensor elements. Defaults to None, i.e. torch.float32.
requires_grad (bool) – Whether to track gradients. Defaults to False.
pin_memory (bool) – Whether to use pinned memory. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor with uninitialized data.
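The lsizes/rsizes convention shared by from_empty(), from_ones(), from_rand(), from_randn(), and from_zeros() is a shape computation: each jagged length is followed by the regular sizes. The sketch below is plain Python with a hypothetical shapes_from_sizes helper; it only reproduces the per-tensor shapes described in these examples and allocates nothing.

```python
def shapes_from_sizes(lsizes, rsizes=None):
    """Hypothetical helper: expand lsizes/rsizes into per-tensor shapes."""
    rsizes = tuple(rsizes or ())  # None means scalar elements
    shapes = []
    for entry in lsizes:
        if isinstance(entry, (list, tuple)):
            # Nested lsizes: one inner list of lengths per outer list.
            shapes.extend((n, *rsizes) for n in entry)
        else:
            shapes.append((entry, *rsizes))
    return shapes


print(shapes_from_sizes([2, 3, 4], rsizes=[5]))         # [(2, 5), (3, 5), (4, 5)]
print(shapes_from_sizes([[2, 3], [4]], rsizes=[5, 6]))  # [(2, 5, 6), (3, 5, 6), (4, 5, 6)]
```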
- classmethod from_list_of_lists_of_tensors(tensors: Sequence[Sequence[Tensor]]) → JaggedTensor
Create a JaggedTensor from nested sequences of torch.Tensor objects.
Creates a multi-level jagged structure where both outer and inner sequences can have varying lengths.
- Parameters:
tensors (Sequence[Sequence[torch.Tensor]]) – Nested list/tuple of torch.Tensor objects.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.
- classmethod from_list_of_tensors(tensors: Sequence[Tensor]) → JaggedTensor
Create a JaggedTensor from a sequence of tensors with varying first dimensions.
All tensors must have the same shape except for the first dimension, which can vary, e.g. [tensor_1, tensor_2, ..., tensor_N], where each tensor_i has shape (L_i, D_1, D_2, ...) with varying L_i.
- Parameters:
tensors (Sequence[torch.Tensor]) – List or tuple of torch.Tensor with compatible shapes.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor containing the sequence of tensors.
- classmethod from_ones(lsizes: Sequence[int] | Sequence[Sequence[int]], rsizes: Sequence[int] | None = None, *, device: device | str | None = None, dtype: dtype | None = None, requires_grad: bool = False, pin_memory: bool = False) → JaggedTensor
Create a JaggedTensor filled with ones.
Similar to torch.ones(), creates a JaggedTensor where all elements are initialized to the value 1.
Example:
jt = JaggedTensor.from_ones([2, 3, 4], rsizes=[5])
# jt holds tensors of shapes (2, 5), (3, 5), and (4, 5), filled with ones.
jt = JaggedTensor.from_ones([[2, 3], [4]], rsizes=[5, 6])
# jt holds tensors of shapes (2, 5, 6), (3, 5, 6), and (4, 5, 6), filled with ones.
- Parameters:
lsizes (Sequence[int] | Sequence[Sequence[int]]) – Sizes for the jagged dimensions. Can be a sequence of integers for a simple jagged structure, or nested sequences for a multi-level jagged structure.
rsizes (Sequence[int] | None) – Sizes for the regular (trailing) dimensions. Defaults to None, i.e. scalar elements.
device (torch.device | str | None) – Device to create the tensor on. Defaults to None, i.e. "cpu".
dtype (torch.dtype | None) – Data type for the tensor elements. Defaults to None, i.e. torch.float32.
requires_grad (bool) – Whether to track gradients. Defaults to False.
pin_memory (bool) – Whether to use pinned memory. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor filled with ones.
- classmethod from_rand(lsizes: Sequence[int] | Sequence[Sequence[int]], rsizes: Sequence[int] | None = None, *, device: device | str | None = None, dtype: dtype | None = None, requires_grad: bool = False, pin_memory: bool = False) → JaggedTensor
Create a JaggedTensor with random values from a uniform distribution on [0, 1).
Similar to torch.rand(), creates a JaggedTensor filled with random values sampled from a uniform distribution on the interval [0, 1).
Example:
jt = JaggedTensor.from_rand([2, 3, 4], rsizes=[5])
# jt holds tensors of shapes (2, 5), (3, 5), and (4, 5) with random values.
jt = JaggedTensor.from_rand([[2, 3], [4]], rsizes=[5, 6])
# jt holds tensors of shapes (2, 5, 6), (3, 5, 6), and (4, 5, 6) with random values.
- Parameters:
lsizes (Sequence[int] | Sequence[Sequence[int]]) – Sizes for the jagged dimensions. Can be a sequence of integers for a simple jagged structure, or nested sequences for a multi-level jagged structure.
rsizes (Sequence[int] | None) – Sizes for the regular (trailing) dimensions. Defaults to None, i.e. scalar elements.
device (torch.device | str | None) – Device to create the tensor on. Defaults to None, i.e. "cpu".
dtype (torch.dtype | None) – Data type for the tensor elements. Defaults to None, i.e. torch.float32.
requires_grad (bool) – Whether to track gradients. Defaults to False.
pin_memory (bool) – Whether to use pinned memory. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor with random values in [0, 1).
- classmethod from_randn(lsizes: Sequence[int] | Sequence[Sequence[int]], rsizes: Sequence[int] | None = None, *, device: device | str | None = None, dtype: dtype | None = None, requires_grad: bool = False, pin_memory: bool = False) → JaggedTensor
Create a JaggedTensor with random values from a standard normal distribution.
Similar to torch.randn(), creates a JaggedTensor filled with random values sampled from a standard normal distribution (mean=0, std=1).
Example:
jt = JaggedTensor.from_randn([2, 3, 4], rsizes=[5])
# jt holds tensors of shapes (2, 5), (3, 5), and (4, 5) with normal random values.
jt = JaggedTensor.from_randn([[2, 3], [4]], rsizes=[5, 6])
# jt holds tensors of shapes (2, 5, 6), (3, 5, 6), and (4, 5, 6) with normal random values.
- Parameters:
lsizes (Sequence[int] | Sequence[Sequence[int]]) – Sizes for the jagged dimensions. Can be a sequence of integers for a simple jagged structure, or nested sequences for a multi-level jagged structure.
rsizes (Sequence[int] | None) – Sizes for the regular (trailing) dimensions. Defaults to None, i.e. scalar elements.
device (torch.device | str | None) – Device to create the tensor on. Defaults to None, i.e. "cpu".
dtype (torch.dtype | None) – Data type for the tensor elements. Defaults to None, i.e. torch.float32.
requires_grad (bool) – Whether to track gradients. Defaults to False.
pin_memory (bool) – Whether to use pinned memory. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor with normal random values.
- classmethod from_tensor(data: Tensor) → JaggedTensor
Create a JaggedTensor from a single torch.Tensor.
- Parameters:
data (torch.Tensor) – The input tensor.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor wrapping the input tensor.
- classmethod from_zeros(lsizes: Sequence[int] | Sequence[Sequence[int]], rsizes: Sequence[int] | None = None, *, device: device | str | None = None, dtype: dtype | None = None, requires_grad: bool = False, pin_memory: bool = False) → JaggedTensor
Create a JaggedTensor filled with zeros.
Similar to torch.zeros(), creates a JaggedTensor where all elements are initialized to the value 0.
Example
jt = JaggedTensor.from_zeros([2, 3, 4], rsizes=[5])
# jt holds tensors of shapes (2, 5), (3, 5), and (4, 5), filled with zeros.
jt = JaggedTensor.from_zeros([[2, 3], [4]], rsizes=[5, 6])
# jt holds tensors of shapes (2, 5, 6), (3, 5, 6), and (4, 5, 6), filled with zeros.
- Parameters:
lsizes (Sequence[int] | Sequence[Sequence[int]]) – Sizes for the jagged dimensions. Can be a sequence of integers for a simple jagged structure, or nested sequences for a multi-level jagged structure.
rsizes (Sequence[int] | None) – Sizes for the regular (trailing) dimensions. Defaults to None, i.e. scalar elements.
device (torch.device | str | None) – Device to create the tensor on. Defaults to None, i.e. "cpu".
dtype (torch.dtype | None) – Data type for the tensor elements. Defaults to None, i.e. torch.float32.
requires_grad (bool) – Whether to track gradients. Defaults to False.
pin_memory (bool) – Whether to use pinned memory. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor filled with zeros.
- int() → JaggedTensor
Convert elements to int (int32) dtype.
- Returns:
JaggedTensor – A new JaggedTensor with int32 dtype.
- property is_cpu: bool
Whether this JaggedTensor is stored on the CPU.
- Returns:
bool – True if on CPU, False otherwise.
- property is_cuda: bool
Whether this JaggedTensor is stored on a CUDA device.
- Returns:
bool – True if on CUDA, False otherwise.
- jagged_like(data: Tensor) → JaggedTensor
Create a new JaggedTensor with the same structure but different data.
The new JaggedTensor will have the same jagged structure (joffsets, jidx, etc.) as the current one, but with new jdata values.
- Parameters:
data (torch.Tensor) – New data tensor with compatible shape. Must have the same leading dimension as self.jdata.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with the same structure but new data.
- property jdata: Tensor
Flattened data tensor containing all elements in this JaggedTensor.
For example, if this JaggedTensor represents three tensors of shapes (2, 4), (3, 4), and (1, 4), then jdata will have shape (6, 4).
- Returns:
torch.Tensor – The data tensor.
- jflatten(dim: int = 0) → JaggedTensor
Flatten the jagged dimensions starting from the specified dimension.
Example
# Original jagged tensor with 2 jagged dimensions,
# representing the nested list [ [ t_00, t_01, … ], [ t_b0, t_b1, … ] ]
jt = JaggedTensor.from_list_of_lists_of_tensors(…)

# Flatten starting from dim=0
jt_flat = jt.jflatten(dim=0)

# jt_flat now has 1 jagged dimension and represents
# [ t_00, t_01, …, t_b0, t_b1, … ]
- Parameters:
dim (int) – The dimension from which to start flattening. Defaults to 0.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with flattened jagged structure.
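For a two-level structure, the effect of jflatten(dim=0) described in the example is a concatenation of the inner lists that keeps tensor order. The plain-Python model below (hypothetical jflatten_model helper, with strings standing in for tensors) applies that rule both to the tensors and to a nested lshape; it illustrates the resulting structure only, not how fvdb performs the operation.

```python
def jflatten_model(nested):
    """Hypothetical helper: concatenate the inner lists, preserving order."""
    return [item for inner in nested for item in inner]


tensors = [["t_00", "t_01"], ["t_10", "t_11", "t_12"]]
print(jflatten_model(tensors))  # ['t_00', 't_01', 't_10', 't_11', 't_12']

# The same rule applied to a nested lshape gives the flattened lshape.
print(jflatten_model([[2, 1], [4]]))  # [2, 1, 4]
```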
- property jidx: Tensor
Per-element tensor indices for the jagged structure. jidx maps each element of the jdata tensor to the index of the tensor it belongs to in the jagged layout (the same encoding accepted by from_data_and_indices()).
Example:
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the jidx values are [0, 0, 1, 1, 1, 2].
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.jidx)  # values: [0, 0, 1, 1, 1, 2]
- Returns:
torch.Tensor – The jagged indices tensor.
- property jlidx: Tensor
List indices for nested jagged structures. This is a torch.Tensor with one (outer_idx, inner_idx) row per tensor, giving that tensor's position in the nested list structure (the same layout as the list_ids argument of from_data_offsets_and_list_ids()).
Example:
# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the jlidx tensor would be: tensor([[0, 0], [0, 1], [1, 0]]).
jt = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt.jlidx)  # Output: tensor([[0, 0], [0, 1], [1, 0]])
- Returns:
torch.Tensor – The jagged list indices tensor.
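The (outer_idx, inner_idx) rows shown in the jlidx example follow directly from a nested lshape: every inner entry is one tensor. A plain-Python sketch with a hypothetical list_ids_from_lshape helper (no fvdb involved):

```python
def list_ids_from_lshape(lshape):
    """Hypothetical helper: one [outer, inner] row per tensor of a nested lshape."""
    return [[outer, inner]
            for outer, inner_sizes in enumerate(lshape)
            for inner in range(len(inner_sizes))]


# lshape of the example: list 0 holds lengths 2 and 1, list 1 holds length 4.
print(list_ids_from_lshape([[2, 1], [4]]))  # [[0, 0], [0, 1], [1, 0]]
```

Only the number of tensors in each inner list matters here; the tensor lengths themselves do not affect the list IDs.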
- jmax(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]
Compute the maximum along a dimension of each tensor in the jagged structure.
Returns both the maximum values and the indices where they occur.
Example
# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the maximum along the jagged dimension (dim=0)
values, indices = jt.jmax(dim=0)

# values is now a jagged tensor containing the maximum values from each tensor
# along dim=0; this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).indices for t in [t1, t2, t3]])
- Parameters:
dim (int) – The dimension along which to compute max for each tensor. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
values (JaggedTensor) – A JaggedTensor containing the maximum values.
indices (JaggedTensor) – A JaggedTensor containing the indices of the maximum values.
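The per-tensor computation in the jmax example can be modelled without fvdb or torch. The sketch below uses a hypothetical per_tensor_max helper on nested lists (each tensor t_i is a list of L_i rows of D numbers) and returns, for each tensor, the column-wise maximum over dim 0 and the row where it occurs. Ties resolve to the first occurrence in this model; fvdb's tie-breaking is not specified on this page.

```python
def per_tensor_max(tensors):
    """Hypothetical helper: column-wise max over dim 0 of each (L_i, D) tensor."""
    values, indices = [], []
    for rows in tensors:
        cols = list(zip(*rows))  # transpose into D columns of length L_i
        values.append([max(c) for c in cols])
        # Row position of the first maximum in each column.
        indices.append([c.index(max(c)) for c in cols])
    return values, indices


t1 = [[1, 5], [3, 2]]          # shape (2, 2)
t2 = [[0, 0], [4, 1], [2, 9]]  # shape (3, 2)
t3 = [[7, 6]]                  # shape (1, 2)
values, indices = per_tensor_max([t1, t2, t3])
print(values)   # [[3, 5], [4, 9], [7, 6]]
print(indices)  # [[1, 0], [1, 2], [0, 0]]
```

jmin follows the same pattern with min in place of max.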
- jmin(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]
Compute the minimum along a dimension of each tensor in the jagged structure.
Returns both the minimum values and the indices where they occur.
Example:
# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the minimum along the jagged dimension (dim=0)
values, indices = jt.jmin(dim=0)

# values is now a jagged tensor containing the minimum values from each tensor
# along dim=0; this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).indices for t in [t1, t2, t3]])
- Parameters:
dim (int) – The dimension along which to compute min for each tensor. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
values (JaggedTensor) – A JaggedTensor containing the minimum values.
indices (JaggedTensor) – A JaggedTensor containing the indices of the minimum values.
- property joffsets: Tensor
Offsets marking boundaries between tensors: tensor i occupies jdata[joffsets[i]:joffsets[i+1]].
Example:
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the joffsets tensor would be: tensor([0, 2, 5, 6]).
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.joffsets)  # Output: tensor([0, 2, 5, 6])

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the joffsets tensor would be: tensor([0, 2, 3, 7]).
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.joffsets)  # Output: tensor([0, 2, 3, 7])
- Returns:
torch.Tensor – The jagged offsets tensor.
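The joffsets values in both examples are running totals of the tensor lengths listed in lshape, with nested lists visited in order. The plain-Python model below (hypothetical offsets_and_indices helper, no fvdb) derives those offsets and, assuming, as from_data_indices_and_list_ids() describes, that per-element indices hold the index of the tensor each element belongs to, the matching per-element index list as well.

```python
def tensor_lengths(lshape):
    """Flatten a flat or nested lshape into a list of per-tensor lengths."""
    lengths = []
    for entry in lshape:
        if isinstance(entry, (list, tuple)):
            lengths.extend(entry)
        else:
            lengths.append(entry)
    return lengths


def offsets_and_indices(lshape):
    """Hypothetical helper: (offsets, per-element tensor indices) from lshape."""
    offsets, indices = [0], []
    for tensor_id, n in enumerate(tensor_lengths(lshape)):
        offsets.append(offsets[-1] + n)  # end of this tensor's slice
        indices.extend([tensor_id] * n)  # one entry per element
    return offsets, indices


print(offsets_and_indices([2, 3, 1]))      # ([0, 2, 5, 6], [0, 0, 1, 1, 1, 2])
print(offsets_and_indices([[2, 1], [4]]))  # ([0, 2, 3, 7], [0, 0, 1, 2, 2, 2, 2])
```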
- jreshape(lshape: Sequence[int] | Sequence[Sequence[int]]) → JaggedTensor
Reshape the jagged dimensions to new sizes.
- Parameters:
lshape (Sequence[int] | Sequence[Sequence[int]]) – New shape(s) for jagged dimensions. Can be a single sequence of sizes or nested sequences for multi-level structure.
- Returns:
JaggedTensor – A new JaggedTensor with reshaped jagged structure.
- jreshape_as(other: JaggedTensor | Tensor) → JaggedTensor
Reshape the jagged structure to match another JaggedTensor or Tensor.
- Parameters:
other (JaggedTensor | torch.Tensor) – The target structure to match.
- Returns:
JaggedTensor – A new JaggedTensor with structure matching other.
- jsqueeze(dim: int | None = None) → JaggedTensor
Remove singleton dimensions from the jagged structure.
- Parameters:
dim (int | None) – Specific dimension to squeeze, or None to squeeze all singleton dimensions. Defaults to None.
- Returns:
JaggedTensor – A new JaggedTensor with singleton dimensions removed.
- jsum(dim: int = 0, keepdim: bool = False) → JaggedTensor
Sum along a jagged dimension.
- Parameters:
dim (int) – The jagged dimension along which to sum. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor with values summed along the specified dimension.
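Assuming jsum(dim=0) reduces each tensor over its own leading dimension, analogous to the jmax() example, the computation amounts to a segment sum over jdata driven by joffsets. The plain-Python sketch below (hypothetical segment_sum helper, nested lists standing in for tensors) shows that model only; it does not describe how fvdb implements the reduction or what shape keepdim produces.

```python
def segment_sum(data, offsets):
    """Hypothetical helper: column-wise sum of rows in each offsets segment."""
    width = len(data[0]) if data else 0
    sums = []
    for i in range(len(offsets) - 1):
        rows = data[offsets[i]:offsets[i + 1]]
        # An empty segment sums to zeros in this model.
        sums.append([sum(row[j] for row in rows) for j in range(width)])
    return sums


data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]  # jdata, shape (6, 2)
offsets = [0, 2, 5, 6]                                      # three tensors
print(segment_sum(data, offsets))  # [[4, 6], [21, 24], [11, 12]]
```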
- property ldim: int
Dimensionality of the jagged (leading) structure, i.e. the number of jagged levels.
If the JaggedTensor represents a simple jagged structure (a single list of tensors), then ldim will be 1. For nested jagged structures (lists of lists of tensors), ldim will be greater than 1.
- Returns:
int – The dimensionality of the jagged structure.
- long() → JaggedTensor
Convert elements to long (int64) dtype.
- Returns:
JaggedTensor – A new JaggedTensor with int64 dtype.
- property lshape: list[int] | list[list[int]]
List structure shape(s) of the jagged dimensions.
Example:
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the lshape will be [2, 3, 1] (the lengths of the three tensors).
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.lshape)  # Output: [2, 3, 1]

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the lshape will be [[2, 1], [4]].
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.lshape)  # Output: [[2, 1], [4]]
- Returns:
list[int] | list[list[int]] – The jagged structure shapes.
- property num_tensors: int
Return the total number of tensors in the jagged sequence.
- Returns:
int – Number of tensors in this JaggedTensor.
- property requires_grad: bool
Whether this JaggedTensor requires gradient computation.
- Returns:
bool – True if gradients are tracked, False otherwise.
- requires_grad_(requires_grad: bool) → JaggedTensor
Set the requires_grad attribute in-place.
- Parameters:
requires_grad (bool) – Whether to track gradients for this tensor.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- rmask(mask: Tensor) → JaggedTensor
Apply a mask to filter elements along the regular (non-jagged) dimension.
- Parameters:
mask (torch.Tensor) – Boolean mask tensor to apply. Shape must be compatible with the regular dimensions.
- Returns:
JaggedTensor – A new JaggedTensor with masked elements.
- round(decimals: int = 0) → JaggedTensor
Round elements to the specified number of decimals.
- Parameters:
decimals (int) – Number of decimal places to round to. Defaults to 0.
- Returns:
JaggedTensor – A new JaggedTensor with rounded values.
- round_(decimals: int = 0) → JaggedTensor
Round elements to the specified number of decimals in-place.
- Parameters:
decimals (int) – Number of decimal places to round to. Defaults to 0.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- property rshape: tuple[int, ...]
Return the shape of the jdata tensor.
Note
rshape stands for “raw shape” and represents the full shape of the underlying data tensor: the flattened jagged (leading) dimension followed by the regular dimensions.
- Returns:
tuple[int, …] – Shape of the underlying data tensor.
- sqrt() → JaggedTensor
Compute the square root element-wise.
- Returns:
JaggedTensor – A new JaggedTensor with square root applied.
- sqrt_() → JaggedTensor
Compute the square root element-wise in-place.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- to(device_or_dtype: device | str | dtype) → JaggedTensor
Move the JaggedTensor to a device or convert it to a dtype.
- Parameters:
device_or_dtype (torch.device | str | torch.dtype) – Target torch.device or torch.dtype. Can be a device ("cpu", "cuda") or a dtype (torch.float32, etc.).
- Returns:
JaggedTensor – A new JaggedTensor on the specified device or with the specified dtype.
- type(dtype: dtype) → JaggedTensor
Convert the JaggedTensor to a specific dtype.
- Parameters:
dtype (torch.dtype) – Target data type (e.g. torch.float32, torch.int64).
- Returns:
JaggedTensor – A new JaggedTensor with the specified dtype.
- type_as(other: JaggedTensor | Tensor) → JaggedTensor
Convert the JaggedTensor to match the dtype of another tensor.
- Parameters:
other (JaggedTensor | torch.Tensor) – Reference torch.Tensor or JaggedTensor whose dtype to match.
- Returns:
JaggedTensor – A new JaggedTensor with dtype matching other.
- unbind() → list[Tensor] | list[list[Tensor]]
Unbind the JaggedTensor into its constituent tensors.
- Returns:
list[torch.Tensor] | list[list[torch.Tensor]] – A list of torch.Tensor (for a simple jagged structure) or a list of lists of torch.Tensor (for a nested structure).