viva_tensor/named

Named Tensor - Tensors with semantic axis names

Wrap tensors with named axes for clearer, safer operations. Instead of `sum(t, axis: 0)`, write `sum_along(t, Batch)`.

Types

Tensor with named axes

pub type NamedTensor {
  NamedTensor(data: tensor.Tensor, axes: List(axis.AxisSpec))
}

Constructors

  • NamedTensor(data: tensor.Tensor, axes: List(axis.AxisSpec))

    Arguments

    data

    Underlying data tensor

    axes

    Axis specifications (names + sizes, in order)

Error types for named tensor operations

pub type NamedTensorError {
  AxisNotFound(name: axis.Axis)
  DuplicateAxis(name: axis.Axis)
  AxisMismatch(expected: axis.Axis, got: axis.Axis)
  SizeMismatch(axis: axis.Axis, expected: Int, got: Int)
  BroadcastErr(reason: String)
  TensorErr(error.TensorError)
  InvalidOp(reason: String)
}

Constructors

  • AxisNotFound(name: axis.Axis)

    Axis not found

  • DuplicateAxis(name: axis.Axis)

    Duplicate axis name

  • AxisMismatch(expected: axis.Axis, got: axis.Axis)

    Axis mismatch in operation

  • SizeMismatch(axis: axis.Axis, expected: Int, got: Int)

An axis has a different size than expected (for example, the same named axis has different sizes in two operands)

  • BroadcastErr(reason: String)

The axes cannot be broadcast together

  • TensorErr(error.TensorError)

    Underlying tensor error

  • InvalidOp(reason: String)

    Invalid operation

Values

pub fn add(
  a: NamedTensor,
  b: NamedTensor,
) -> Result(NamedTensor, NamedTensorError)

Element-wise addition (both tensors must have the same axes)

pub fn axis_names(t: NamedTensor) -> List(axis.Axis)

Get all axis names

pub fn axis_size(
  t: NamedTensor,
  name: axis.Axis,
) -> Result(Int, NamedTensorError)

Get axis size by name

pub fn describe(t: NamedTensor) -> String

Return a human-readable description of the tensor as a string

pub fn find_axis(
  t: NamedTensor,
  name: axis.Axis,
) -> Result(Int, NamedTensorError)

Find axis index by name

pub fn from_tensor(t: tensor.Tensor) -> NamedTensor

Create a named tensor from a plain tensor, assigning inferred anonymous axes

pub fn has_axis(t: NamedTensor, name: axis.Axis) -> Bool

Check if tensor has axis

pub fn map(t: NamedTensor, f: fn(Float) -> Float) -> NamedTensor

Apply a function to every element

pub fn mean_along(
  t: NamedTensor,
  axis_name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)

Mean along named axis

pub fn mul(
  a: NamedTensor,
  b: NamedTensor,
) -> Result(NamedTensor, NamedTensorError)

Element-wise multiplication (both tensors must have the same axes)

pub fn new(
  data: tensor.Tensor,
  axes: List(axis.AxisSpec),
) -> Result(NamedTensor, NamedTensorError)

Create named tensor from data and axis specs

pub fn ones(axes: List(axis.AxisSpec)) -> NamedTensor

Create named tensor of ones

pub fn randn(
  axes: List(axis.AxisSpec),
  mean: Float,
  std: Float,
) -> NamedTensor

Create a named tensor with values drawn from a normal distribution with the given mean and standard deviation

pub fn random(axes: List(axis.AxisSpec)) -> NamedTensor

Create a named tensor with random values in the range [0, 1)

pub fn rank(t: NamedTensor) -> Int

Get rank (number of dimensions)

pub fn rename_axis(
  t: NamedTensor,
  from: axis.Axis,
  to: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)

Rename an axis

pub fn scale(t: NamedTensor, s: Float) -> NamedTensor

Scale by constant

pub fn shape(t: NamedTensor) -> List(Int)

Get shape as list

pub fn size(t: NamedTensor) -> Int

Total number of elements

pub fn squeeze(
  t: NamedTensor,
  name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)

Remove an axis of size 1, identified by name

pub fn sum_along(
  t: NamedTensor,
  axis_name: axis.Axis,
) -> Result(NamedTensor, NamedTensorError)

Sum along named axis

pub fn to_tensor(t: NamedTensor) -> tensor.Tensor

Convert to plain tensor (drop names)

pub fn unsqueeze(
  t: NamedTensor,
  name: axis.Axis,
  position: Int,
) -> NamedTensor

Insert a new axis of size 1 with the given name at the given position

pub fn zeros(axes: List(axis.AxisSpec)) -> NamedTensor

Create named tensor of zeros

Search Document