viva_tensor/named
Named Tensor - Tensors with semantic axis names
Wrap tensors with named axes for clearer, safer operations. Instead of sum(t, axis: 0), write sum_along(t, Batch).
Types
Tensor with named axes
pub type NamedTensor {
NamedTensor(data: tensor.Tensor, axes: List(axis.AxisSpec))
}
Constructors
- NamedTensor(data: tensor.Tensor, axes: List(axis.AxisSpec))
  Arguments
  - data: Underlying data tensor
  - axes: Axis specifications (names and sizes, in order)
Error types for named tensor operations
pub type NamedTensorError {
AxisNotFound(name: axis.Axis)
DuplicateAxis(name: axis.Axis)
AxisMismatch(expected: axis.Axis, got: axis.Axis)
SizeMismatch(axis: axis.Axis, expected: Int, got: Int)
BroadcastErr(reason: String)
TensorErr(error.TensorError)
InvalidOp(reason: String)
}
Constructors
- AxisNotFound(name: axis.Axis)
  Axis not found
- DuplicateAxis(name: axis.Axis)
  Duplicate axis name
- AxisMismatch(expected: axis.Axis, got: axis.Axis)
  Axis mismatch in operation
- SizeMismatch(axis: axis.Axis, expected: Int, got: Int)
  Size mismatch for the same axis
- BroadcastErr(reason: String)
  Axes cannot be broadcast
- TensorErr(error.TensorError)
  Underlying tensor error
- InvalidOp(reason: String)
  Invalid operation
Values
pub fn add(
a: NamedTensor,
b: NamedTensor,
) -> Result(NamedTensor, NamedTensorError)
Element-wise add (same axes required)
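A minimal sketch of the axis check an element-wise operation such as add or mul might run before touching the data. It assumes the axes must match position by position (same name, same size, same order). The Axis, AxisSpec and check_same_axes definitions here are hypothetical stand-ins, not the viva_tensor types, and reporting a rank mismatch as InvalidOp is a guess.

```gleam
// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

// Mirrors a subset of the documented NamedTensorError constructors.
pub type AxisCheckError {
  AxisMismatch(expected: Axis, got: Axis)
  SizeMismatch(axis: Axis, expected: Int, got: Int)
  InvalidOp(reason: String)
}

/// Succeeds only if both axis lists agree position by position.
pub fn check_same_axes(
  a: List(AxisSpec),
  b: List(AxisSpec),
) -> Result(Nil, AxisCheckError) {
  case a, b {
    [], [] -> Ok(Nil)
    [x, ..rest_a], [y, ..rest_b] ->
      case x.name == y.name, x.size == y.size {
        False, _ -> Error(AxisMismatch(expected: x.name, got: y.name))
        True, False ->
          Error(SizeMismatch(axis: x.name, expected: x.size, got: y.size))
        True, True -> check_same_axes(rest_a, rest_b)
      }
    // One list ran out before the other.
    _, _ -> Error(InvalidOp(reason: "operands have different numbers of axes"))
  }
}
```

Checking names before sizes means a tensor laid out as [Feature, Batch] is rejected against [Batch, Feature] even when the sizes happen to line up, which is the kind of silent transpose bug named axes are meant to catch.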
pub fn axis_size(
t: NamedTensor,
name: axis.Axis,
) -> Result(Int, NamedTensorError)
Get axis size by name
pub fn find_axis(
t: NamedTensor,
name: axis.Axis,
) -> Result(Int, NamedTensorError)
Find axis index by name
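Both find_axis and axis_size amount to a linear scan of the axis list. The sketch below works on a bare List(AxisSpec) rather than a NamedTensor, and its Axis, AxisSpec and LookupError types are hypothetical simplifications, not the library's definitions.

```gleam
// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
  Time
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

pub type LookupError {
  AxisNotFound(name: Axis)
}

/// Zero-based position of the axis called `name`.
pub fn find_axis(axes: List(AxisSpec), name: Axis) -> Result(Int, LookupError) {
  find_from(axes, name, 0)
}

fn find_from(
  axes: List(AxisSpec),
  name: Axis,
  index: Int,
) -> Result(Int, LookupError) {
  case axes {
    [] -> Error(AxisNotFound(name))
    [spec, ..rest] ->
      case spec.name == name {
        True -> Ok(index)
        False -> find_from(rest, name, index + 1)
      }
  }
}

/// Size of the axis called `name`.
pub fn axis_size(axes: List(AxisSpec), name: Axis) -> Result(Int, LookupError) {
  case axes {
    [] -> Error(AxisNotFound(name))
    [spec, ..rest] ->
      case spec.name == name {
        True -> Ok(spec.size)
        False -> axis_size(rest, name)
      }
  }
}
```

With axes [Batch 32, Time 10, Feature 64], looking up Time gives index 1 and looking up Feature gives size 64; code that asks for an axis by name keeps working if the axis order changes.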
pub fn from_tensor(t: tensor.Tensor) -> NamedTensor
Create from tensor with inferred anonymous axes
pub fn mean_along(
t: NamedTensor,
axis_name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)
Mean along named axis
pub fn mul(
a: NamedTensor,
b: NamedTensor,
) -> Result(NamedTensor, NamedTensorError)
Element-wise mul (same axes required)
pub fn new(
data: tensor.Tensor,
axes: List(axis.AxisSpec),
) -> Result(NamedTensor, NamedTensorError)
Create named tensor from data and axis specs
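One plausible set of checks for new, sketched under stated assumptions: the data is a flat List(Float) instead of tensor.Tensor, axis names must be unique (reported as DuplicateAxis), and the product of the axis sizes must equal the element count. The FlatTensor type is hypothetical, and which error the real function returns for a count mismatch is not documented; InvalidOp is only a placeholder.

```gleam
import gleam/int
import gleam/list

// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

pub type NewError {
  DuplicateAxis(name: Axis)
  InvalidOp(reason: String)
}

// Hypothetical simplified tensor: flat row-major data plus axis specs.
pub type FlatTensor {
  FlatTensor(data: List(Float), axes: List(AxisSpec))
}

pub fn new(
  data: List(Float),
  axes: List(AxisSpec),
) -> Result(FlatTensor, NewError) {
  case first_duplicate(axes, []) {
    Ok(name) -> Error(DuplicateAxis(name))
    Error(Nil) -> {
      // The product of the axis sizes is the number of elements the shape implies.
      let expected = int.product(list.map(axes, fn(spec: AxisSpec) { spec.size }))
      case expected == list.length(data) {
        True -> Ok(FlatTensor(data, axes))
        False -> Error(InvalidOp("axis sizes do not match element count"))
      }
    }
  }
}

// Returns the first axis name that appears more than once.
fn first_duplicate(axes: List(AxisSpec), seen: List(Axis)) -> Result(Axis, Nil) {
  case axes {
    [] -> Error(Nil)
    [spec, ..rest] ->
      case list.contains(seen, spec.name) {
        True -> Ok(spec.name)
        False -> first_duplicate(rest, [spec.name, ..seen])
      }
  }
}
```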
pub fn randn(
axes: List(axis.AxisSpec),
mean: Float,
std: Float,
) -> NamedTensor
Create a named tensor with values drawn from a normal distribution with the given mean and standard deviation (std)
pub fn random(axes: List(axis.AxisSpec)) -> NamedTensor
Create a named tensor with random values in [0, 1)
pub fn rename_axis(
t: NamedTensor,
from: axis.Axis,
to: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)
Rename an axis
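A sketch of rename_axis over a bare axis list, assuming it fails with AxisNotFound when from is absent and with DuplicateAxis when to already names a different axis. The types are hypothetical stand-ins. Only the name changes; size and position stay the same, so the underlying data would not need to move.

```gleam
import gleam/list

// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
  Time
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

pub type RenameError {
  AxisNotFound(name: Axis)
  DuplicateAxis(name: Axis)
}

pub fn rename_axis(
  axes: List(AxisSpec),
  from: Axis,
  to: Axis,
) -> Result(List(AxisSpec), RenameError) {
  let names = list.map(axes, fn(spec: AxisSpec) { spec.name })
  // Renaming an axis to its own name is allowed and changes nothing.
  let clashes = from != to && list.contains(names, to)
  case list.contains(names, from), clashes {
    False, _ -> Error(AxisNotFound(from))
    True, True -> Error(DuplicateAxis(to))
    True, False ->
      Ok(
        list.map(axes, fn(spec: AxisSpec) {
          case spec.name == from {
            // Record update keeps the size, replaces only the name.
            True -> AxisSpec(..spec, name: to)
            False -> spec
          }
        }),
      )
  }
}
```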
pub fn squeeze(
t: NamedTensor,
name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)
Remove axis of size 1 by name
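A sketch of squeeze on a bare axis list, with hypothetical Axis and AxisSpec types. Reporting a non-unit axis as SizeMismatch with expected 1 is an assumption; the docs only say the axis must have size 1. In a contiguous row-major layout a size-1 axis does not affect element order, so dropping it leaves a flat data buffer unchanged.

```gleam
import gleam/list

// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
  Time
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

pub type SqueezeError {
  AxisNotFound(name: Axis)
  SizeMismatch(axis: Axis, expected: Int, got: Int)
}

pub fn squeeze(
  axes: List(AxisSpec),
  name: Axis,
) -> Result(List(AxisSpec), SqueezeError) {
  case list.find(axes, fn(spec: AxisSpec) { spec.name == name }) {
    Error(Nil) -> Error(AxisNotFound(name))
    // Only a size-1 axis can be dropped without losing elements.
    Ok(AxisSpec(size: 1, ..)) ->
      Ok(list.filter(axes, fn(spec: AxisSpec) { spec.name != name }))
    Ok(spec) -> Error(SizeMismatch(axis: name, expected: 1, got: spec.size))
  }
}
```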
pub fn sum_along(
t: NamedTensor,
axis_name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)
Sum along named axis
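To illustrate what "along" means, here is a sketch restricted to two dimensions, using a hypothetical Matrix type whose outer list runs along one named axis and whose inner lists run along the other. Summing along an axis removes that axis and keeps the other; mean_along would additionally divide each sum by the size of the removed axis. This is not the library's implementation, which works on tensor.Tensor of any rank.

```gleam
import gleam/float
import gleam/list

// Hypothetical axis names for the sketch.
pub type Axis {
  Batch
  Feature
  Time
}

/// Outer list runs along `row_axis`, inner lists along `col_axis`.
pub type Matrix {
  Matrix(row_axis: Axis, col_axis: Axis, rows: List(List(Float)))
}

pub type SumError {
  AxisNotFound(name: Axis)
}

/// Returns the axis that remains plus the 1-D sums along it.
pub fn sum_along(m: Matrix, name: Axis) -> Result(#(Axis, List(Float)), SumError) {
  case name == m.row_axis, name == m.col_axis {
    // Collapse the rows: one sum per column.
    True, _ -> Ok(#(m.col_axis, list.map(list.transpose(m.rows), float.sum)))
    // Collapse the columns: one sum per row.
    _, True -> Ok(#(m.row_axis, list.map(m.rows, float.sum)))
    False, False -> Error(AxisNotFound(name))
  }
}
```

For rows [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] named (Batch, Feature), summing along Batch yields per-Feature totals [5.0, 7.0, 9.0], and summing along Feature yields per-Batch totals [6.0, 15.0].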
pub fn to_tensor(t: NamedTensor) -> tensor.Tensor
Convert to plain tensor (drop names)
pub fn unsqueeze(
t: NamedTensor,
name: axis.Axis,
position: Int,
) -> NamedTensor
Add a new axis of size 1
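A sketch of unsqueeze on a bare axis list with hypothetical Axis and AxisSpec types, using list.split to place the new size-1 axis before index position. How the real function handles an out-of-range position or a name that already exists is not documented; here list.split clamps the position to the list bounds and no duplicate check is made.

```gleam
import gleam/list

// Hypothetical stand-ins for axis.Axis and axis.AxisSpec.
pub type Axis {
  Batch
  Feature
  Time
}

pub type AxisSpec {
  AxisSpec(name: Axis, size: Int)
}

pub fn unsqueeze(
  axes: List(AxisSpec),
  name: Axis,
  position: Int,
) -> List(AxisSpec) {
  // list.split clamps: position <= 0 puts the axis first, and a position
  // past the end appends it.
  let #(before, after) = list.split(axes, position)
  list.append(before, [AxisSpec(name: name, size: 1), ..after])
}
```

Multiplying the shape by 1 leaves the element count unchanged, which is why unsqueeze can plausibly return a NamedTensor directly instead of a Result.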