viva_tensor/tensor
Tensor - N-dimensional arrays for numerical computing
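Design: NumPy-inspired, with strides for zero-copy views. Storage is an Erlang :array, which gives fast indexed access (lookups walk a shallow tuple tree, O(log n) rather than the O(n) of list traversal), and strides make transpose and reshape efficient.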
Types
Opaque type for Erlang :array
pub type ErlangArray
Tensor with NumPy-style strides for zero-copy views
- storage: contiguous data buffer (an Erlang array, for fast indexed access)
- shape: dimensions [d0, d1, …, dn]
- strides: number of storage elements to skip to advance one step along each dimension [s0, s1, …, sn]
- offset: starting position in storage (for views/slices)
pub type Tensor {
  Tensor(data: List(Float), shape: List(Int))
  StridedTensor(
    storage: ErlangArray,
    shape: List(Int),
    strides: List(Int),
    offset: Int,
  )
}
Constructors
- Tensor(data: List(Float), shape: List(Int))
- StridedTensor(storage: ErlangArray, shape: List(Int), strides: List(Int), offset: Int)
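The stride fields can be illustrated with a minimal sketch. It assumes a row-major layout and strides counted in elements; row_major_strides and storage_index are hypothetical helpers, not part of viva_tensor. The element at multi-index [i0, i1, …] sits at offset + i0*s0 + i1*s1 + … in storage.

```gleam
import gleam/list

/// Row-major strides, counted in elements: the last dimension has
/// stride 1 and each earlier stride is the product of all later sizes.
pub fn row_major_strides(shape: List(Int)) -> List(Int) {
  // Fold from the last dimension to the first, carrying the running
  // product of the sizes already visited.
  let #(strides, _total) =
    list.fold(list.reverse(shape), #([], 1), fn(acc, dim) {
      let #(acc_strides, running) = acc
      #([running, ..acc_strides], running * dim)
    })
  strides
}

/// Storage position of a multi-index: offset + sum of index_k * stride_k.
pub fn storage_index(offset: Int, strides: List(Int), index: List(Int)) -> Int {
  list.zip(index, strides)
  |> list.fold(offset, fn(acc, pair) {
    let #(i, s) = pair
    acc + i * s
  })
}
```

For a [2, 3, 4] tensor this gives strides [12, 4, 1], so index [1, 2, 3] maps to storage position 23, the last of the 24 elements.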
Tensor operation errors
pub type TensorError {
  ShapeMismatch(expected: List(Int), got: List(Int))
  InvalidShape(reason: String)
  DimensionError(reason: String)
  BroadcastError(a: List(Int), b: List(Int))
}
Constructors
- ShapeMismatch(expected: List(Int), got: List(Int))
- InvalidShape(reason: String)
- DimensionError(reason: String)
- BroadcastError(a: List(Int), b: List(Int))
Values
pub fn add_broadcast(
  a: Tensor,
  b: Tensor,
) -> Result(Tensor, TensorError)
Element-wise addition with broadcasting
pub fn broadcast_shape(
  a: List(Int),
  b: List(Int),
) -> Result(List(Int), TensorError)
Compute the shape that results from broadcasting shapes a and b together
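A sketch of the standard NumPy broadcasting rule, which broadcast_shape presumably follows (an assumption; viva_tensor's exact error cases may differ). It returns Result(List(Int), Nil) instead of TensorError so that it compiles on its own.

```gleam
import gleam/int
import gleam/list

/// NumPy-style broadcasting of two shapes: left-pad the shorter shape
/// with 1s, then compare dimension by dimension. Equal sizes are kept,
/// a size of 1 stretches to the other size, anything else is an error.
pub fn broadcast_shape(a: List(Int), b: List(Int)) -> Result(List(Int), Nil) {
  let len_a = list.length(a)
  let len_b = list.length(b)
  let padded_a = list.append(list.repeat(1, int.max(0, len_b - len_a)), a)
  let padded_b = list.append(list.repeat(1, int.max(0, len_a - len_b)), b)
  list.zip(padded_a, padded_b)
  |> list.try_map(fn(pair) {
    case pair {
      #(x, y) if x == y -> Ok(x)
      #(1, y) -> Ok(y)
      #(x, 1) -> Ok(x)
      _ -> Error(Nil)
    }
  })
}
```

Under this rule, can_broadcast(a, b) amounts to checking whether the result is Ok.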
pub fn broadcast_to(
  t: Tensor,
  target_shape: List(Int),
) -> Result(Tensor, TensorError)
Broadcast a tensor to the target shape
pub fn can_broadcast(a: List(Int), b: List(Int)) -> Bool
Check if two shapes can be broadcast together
pub fn concat_axis(
  tensors: List(Tensor),
  axis: Int,
) -> Result(Tensor, TensorError)
Concatenate tensors along an existing axis. For two [2, 3] tensors a and b:
concat_axis([a, b], 0) -> [4, 3]
concat_axis([a, b], 1) -> [2, 6]
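The shape rule behind the concat_axis examples, as a standalone sketch. concat_shape is a hypothetical helper that computes only the result shape, not the data: every dimension except axis must match, and the sizes along axis add up.

```gleam
import gleam/list
import gleam/result

/// Shape produced by concatenating tensors along an existing axis.
pub fn concat_shape(
  shapes: List(List(Int)),
  axis: Int,
) -> Result(List(Int), Nil) {
  case shapes {
    [] -> Error(Nil)
    [first, ..rest] ->
      case axis >= 0 && axis < list.length(first) {
        False -> Error(Nil)
        True ->
          list.try_fold(rest, first, fn(acc, shape) { merge(acc, shape, axis) })
      }
  }
}

// Walk two shapes in lockstep; `axis` counts down to 0 at the
// concatenation dimension, where sizes are summed instead of compared.
fn merge(a: List(Int), b: List(Int), axis: Int) -> Result(List(Int), Nil) {
  case a, b {
    [], [] -> Ok([])
    [x, ..xs], [y, ..ys] -> {
      use dim <- result.try(case axis == 0, x == y {
        True, _ -> Ok(x + y)
        False, True -> Ok(x)
        False, False -> Error(Nil)
      })
      use tail <- result.try(merge(xs, ys, axis - 1))
      Ok([dim, ..tail])
    }
    // Shapes of different rank cannot be concatenated.
    _, _ -> Error(Nil)
  }
}
```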
pub fn from_list2d(
  rows: List(List(Float)),
) -> Result(Tensor, TensorError)
Create a 2D tensor (matrix) from a list of rows
pub fn get(t: Tensor, index: Int) -> Result(Float, TensorError)
Access element by linear index
pub fn get2d(
  t: Tensor,
  row: Int,
  col: Int,
) -> Result(Float, TensorError)
Access the element at (row, col) of a 2D tensor
pub fn get2d_fast(
  t: Tensor,
  row: Int,
  col: Int,
) -> Result(Float, TensorError)
Get a 2D element with fast indexed access (no list traversal)
pub fn get_col(
  t: Tensor,
  col_idx: Int,
) -> Result(Tensor, TensorError)
Get matrix column as vector
pub fn get_fast(
  t: Tensor,
  index: Int,
) -> Result(Float, TensorError)
Get an element by linear index with fast indexed access for StridedTensor
pub fn get_row(
  t: Tensor,
  row_idx: Int,
) -> Result(Tensor, TensorError)
Get matrix row as vector
pub fn matmul(
  a: Tensor,
  b: Tensor,
) -> Result(Tensor, TensorError)
Matrix-matrix multiplication: [m, n] @ [n, p] -> [m, p]
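A naive reference for the [m, n] @ [n, p] contract, written over plain row lists rather than Tensor values. The row-list representation and the function names are assumptions for illustration; this is the textbook O(m·n·p) algorithm, not necessarily how matmul is implemented.

```gleam
import gleam/float
import gleam/list

/// Naive [m, n] @ [n, p] -> [m, p] on a row-list representation.
/// Returns Error(Nil) when a row of `a` is not n long (n = rows of `b`).
pub fn matmul(
  a: List(List(Float)),
  b: List(List(Float)),
) -> Result(List(List(Float)), Nil) {
  let n = list.length(b)
  case list.all(a, fn(row) { list.length(row) == n }) {
    False -> Error(Nil)
    True -> {
      // Transpose b once so each of its columns becomes a list,
      // turning every output entry into a plain dot product.
      let b_cols = list.transpose(b)
      Ok(list.map(a, fn(row) { list.map(b_cols, fn(col) { dot(row, col) }) }))
    }
  }
}

fn dot(x: List(Float), y: List(Float)) -> Float {
  list.map2(x, y, fn(p, q) { p *. q })
  |> float.sum
}
```

matmul_vec is the special case p = 1: each output entry is dot(row, vec).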
pub fn matmul_vec(
  mat: Tensor,
  vec: Tensor,
) -> Result(Tensor, TensorError)
Matrix-vector multiplication: [m, n] @ [n] -> [m]
pub fn matrix(
  rows: Int,
  cols: Int,
  data: List(Float),
) -> Result(Tensor, TensorError)
Create matrix (2D tensor) with explicit dimensions
pub fn mean_axis(
  t: Tensor,
  axis_idx: Int,
) -> Result(Tensor, TensorError)
Mean along a specific axis
pub fn mul(a: Tensor, b: Tensor) -> Result(Tensor, TensorError)
Element-wise multiplication (Hadamard)
pub fn mul_broadcast(
  a: Tensor,
  b: Tensor,
) -> Result(Tensor, TensorError)
Element-wise multiplication with broadcasting
pub fn outer(a: Tensor, b: Tensor) -> Result(Tensor, TensorError)
Outer product of two vectors: [m] and [n] -> [m, n]
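The definition of the outer product, sketched over plain lists (the list representation and the name outer are illustrative assumptions): entry (i, j) of the result is a[i] * b[j].

```gleam
import gleam/list

/// Outer product of vectors a (length m) and b (length n): an m x n
/// matrix, as a list of rows, whose entry (i, j) is a[i] * b[j].
pub fn outer(a: List(Float), b: List(Float)) -> List(List(Float)) {
  list.map(a, fn(x) { list.map(b, fn(y) { x *. y }) })
}
```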
pub fn random_normal(
  shape: List(Int),
  mean_val: Float,
  std_val: Float,
) -> Tensor
Tensor of normally distributed random values with mean mean_val and standard deviation std_val (generated via the Box-Muller transform)
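A sketch of the textbook Box-Muller transform, not necessarily the exact code random_normal uses. The externals call Erlang's math and rand modules directly, which assumes the Erlang target; box_muller and sample_normal are hypothetical names.

```gleam
@external(erlang, "math", "log")
fn ln(x: Float) -> Float

@external(erlang, "math", "cos")
fn cos(x: Float) -> Float

@external(erlang, "math", "sqrt")
fn sqrt(x: Float) -> Float

@external(erlang, "math", "pi")
fn pi() -> Float

// Erlang's rand:uniform/0 returns a float in [0.0, 1.0).
@external(erlang, "rand", "uniform")
fn uniform() -> Float

/// Box-Muller: two independent uniform samples, u1 in (0, 1] and
/// u2 in [0, 1), give one standard normal sample z, which is then
/// scaled to the requested mean and standard deviation.
pub fn box_muller(u1: Float, u2: Float, mean: Float, std: Float) -> Float {
  let z = sqrt(-2.0 *. ln(u1)) *. cos(2.0 *. pi() *. u2)
  mean +. std *. z
}

/// One normal sample; using 1 - u keeps the logarithm's argument above 0.
pub fn sample_normal(mean: Float, std: Float) -> Float {
  box_muller(1.0 -. uniform(), uniform(), mean, std)
}
```

The transform also yields a second independent sample from the sine term, sqrt(-2 ln u1) * sin(2π u2); this sketch discards it for simplicity.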
pub fn random_uniform(shape: List(Int)) -> Tensor
Tensor of uniformly distributed random values in [0, 1)
pub fn reshape(
  t: Tensor,
  new_shape: List(Int),
) -> Result(Tensor, TensorError)
Reshape a tensor to new_shape; the total number of elements must stay the same
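The validity check a reshape needs, as a minimal sketch. check_reshape is a hypothetical helper, and the non-negative dimension rule is an assumption; the library may reject other shapes as InvalidShape.

```gleam
import gleam/int
import gleam/list

/// A reshape is valid when no new dimension is negative and the
/// element count (product of the dimensions) is unchanged.
pub fn check_reshape(
  old_shape: List(Int),
  new_shape: List(Int),
) -> Result(List(Int), Nil) {
  let non_negative = list.all(new_shape, fn(d) { d >= 0 })
  case non_negative && int.product(old_shape) == int.product(new_shape) {
    True -> Ok(new_shape)
    False -> Error(Nil)
  }
}
```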
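pub fn slice(
  t: Tensor,
  start: List(Int),
  lengths: List(Int),
) -> Result(Tensor, TensorError)
Slice a tensor: extract the sub-tensor that starts at start and spans lengths elements along each dimension (indices start through start + lengths - 1). For example, slice(t, [1], [3]) extracts the elements at indices 1, 2, 3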
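The start/lengths rule, sketched on plain lists. slice_1d and slice_2d are hypothetical helpers, and unlike slice they do no bounds checking: out-of-range requests simply return fewer elements.

```gleam
import gleam/list

/// 1D slice: `length` elements starting at index `start`, i.e. indices
/// start, start + 1, ..., start + length - 1.
pub fn slice_1d(data: List(a), start: Int, length: Int) -> List(a) {
  data
  |> list.drop(start)
  |> list.take(length)
}

/// 2D slice of a row list: the same rule applied per dimension, first to
/// the rows and then to the columns of each kept row.
pub fn slice_2d(
  rows: List(List(a)),
  start: #(Int, Int),
  lengths: #(Int, Int),
) -> List(List(a)) {
  let #(row_start, col_start) = start
  let #(row_len, col_len) = lengths
  rows
  |> slice_1d(row_start, row_len)
  |> list.map(fn(row) { slice_1d(row, col_start, col_len) })
}
```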
pub fn squeeze_axis(
  t: Tensor,
  axis: Int,
) -> Result(Tensor, TensorError)
Remove the dimension at the given axis if its size is 1
pub fn stack(
  tensors: List(Tensor),
  axis: Int,
) -> Result(Tensor, TensorError)
Stack tensors along a new axis. For two [3] tensors a and b:
stack([a, b], 0) -> [2, 3]
stack([a, b], 1) -> [3, 2]
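Unlike concat_axis, stack inserts a brand-new dimension whose size is the number of tensors. A sketch of that shape rule, assuming all inputs must share one shape; stack_shape is a hypothetical helper.

```gleam
import gleam/list

/// Shape produced by stacking equally shaped tensors along a new axis:
/// the number of tensors is inserted at position `axis` (0 to rank).
pub fn stack_shape(
  shapes: List(List(Int)),
  axis: Int,
) -> Result(List(Int), Nil) {
  case shapes {
    [] -> Error(Nil)
    [first, ..rest] -> {
      let same_shape = list.all(rest, fn(s) { s == first })
      let axis_ok = axis >= 0 && axis <= list.length(first)
      case same_shape && axis_ok {
        False -> Error(Nil)
        True -> {
          let count = list.length(shapes)
          Ok(list.append(list.take(first, axis), [count, ..list.drop(first, axis)]))
        }
      }
    }
  }
}
```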
pub fn sum_axis(
  t: Tensor,
  axis_idx: Int,
) -> Result(Tensor, TensorError)
Sum along a specific axis, removing that axis. For a [2, 3] tensor t:
sum_axis(t, 0) -> [3]
sum_axis(t, 1) -> [2]
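The 2D case of the axis rule, sketched on a list of rows (sum_axis_2d and the row-list form are illustrative assumptions): summing over an axis collapses that dimension, which is why a [2, 3] input yields [3] for axis 0 and [2] for axis 1.

```gleam
import gleam/float
import gleam/list

/// Sum over one axis of a matrix stored as a list of rows.
/// Axis 0 collapses the rows (one sum per column); axis 1 collapses the
/// columns (one sum per row). Any other axis is rejected.
pub fn sum_axis_2d(
  rows: List(List(Float)),
  axis: Int,
) -> Result(List(Float), Nil) {
  case axis {
    0 -> Ok(list.map(list.transpose(rows), float.sum))
    1 -> Ok(list.map(rows, float.sum))
    _ -> Error(Nil)
  }
}
```

A mean along an axis follows the same pattern, dividing each sum by the size of the collapsed dimension.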
pub fn to_contiguous(t: Tensor) -> Tensor
Convert a strided tensor back to a regular Tensor (materializes the view by copying its elements)
pub fn to_list2d(
  t: Tensor,
) -> Result(List(List(Float)), TensorError)
Convert a matrix to a list of rows
pub fn to_strided(t: Tensor) -> Tensor
Convert a regular tensor to strided form (an O(n) one-time conversion, after which element access is fast and indexed)
pub fn transpose_strided(
  t: Tensor,
) -> Result(Tensor, TensorError)
Zero-copy transpose: swaps the shape and strides without copying the underlying storage
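Why swapping shape and strides is enough, in a minimal 2D sketch. View is a hypothetical stand-in for StridedTensor that omits storage, since a transpose never touches it; element (i, j) of the transposed view reads the same storage slot as element (j, i) of the original.

```gleam
/// Hypothetical stand-in for a strided view; storage is omitted because
/// a transpose never touches it.
pub type View {
  View(shape: List(Int), strides: List(Int), offset: Int)
}

/// Zero-copy 2D transpose: swap the two dimension sizes and the two
/// strides. Offset (and storage) stay the same; nothing is copied.
pub fn transpose_view(v: View) -> Result(View, Nil) {
  case v.shape, v.strides {
    [rows, cols], [row_stride, col_stride] ->
      Ok(View(..v, shape: [cols, rows], strides: [col_stride, row_stride]))
    _, _ -> Error(Nil)
  }
}

/// Storage position of element (i, j) of a 2D view.
pub fn position(v: View, i: Int, j: Int) -> Result(Int, Nil) {
  case v.strides {
    [s0, s1] -> Ok(v.offset + i * s0 + j * s1)
    _ -> Error(Nil)
  }
}
```

The trade-off: the transposed view is no longer row-major, so a later operation that needs contiguous data has to materialize it first, which is the role to_contiguous plays.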
pub fn xavier_init(fan_in: Int, fan_out: Int) -> Tensor
Xavier (Glorot) initialization for a weight tensor, scaled by fan_in and fan_out
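Xavier initialization has a uniform and a normal variant, and which one xavier_init uses is not stated here. This sketch assumes the uniform variant, drawing weights from [-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). It returns a flat list instead of a Tensor, calls Erlang's math and rand modules (Erlang target assumed), and its names are hypothetical.

```gleam
import gleam/int
import gleam/list

@external(erlang, "math", "sqrt")
fn sqrt(x: Float) -> Float

// Erlang's rand:uniform/0 returns a float in [0.0, 1.0).
@external(erlang, "rand", "uniform")
fn uniform() -> Float

/// Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
pub fn xavier_limit(fan_in: Int, fan_out: Int) -> Float {
  sqrt(6.0 /. int.to_float(fan_in + fan_out))
}

/// fan_in * fan_out weights drawn uniformly from [-limit, limit).
pub fn xavier_uniform(fan_in: Int, fan_out: Int) -> List(Float) {
  let limit = xavier_limit(fan_in, fan_out)
  list.repeat(Nil, fan_in * fan_out)
  |> list.map(fn(_) { { uniform() *. 2.0 -. 1.0 } *. limit })
}
```

The bound shrinks as layers get wider, which keeps activation variance roughly constant from layer to layer.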