viva_tensor/tensor
Tensor - N-dimensional arrays for numerical computing
Design: NumPy-inspired, with strides for zero-copy views. Uses an Erlang :array for fast indexed access (a shallow tree of tuples, so effectively O(1) for practical sizes) and strides for efficient transpose/reshape.
Types
Conv2D configuration
pub type Conv2dConfig {
  Conv2dConfig(
    kernel_h: Int,
    kernel_w: Int,
    stride_h: Int,
    stride_w: Int,
    padding_h: Int,
    padding_w: Int,
  )
}
Constructors
- Conv2dConfig(kernel_h: Int, kernel_w: Int, stride_h: Int, stride_w: Int, padding_h: Int, padding_w: Int)
Opaque type for Erlang :array
pub type ErlangArray
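Tensor type with two constructors: Tensor, a flat list of values (data) plus a shape, and StridedTensor, which uses NumPy-style strides for zero-copy views. StridedTensor fields:
- storage: contiguous data buffer (Erlang array for O(1) access)
- shape: dimensions [d0, d1, …, dn]
- strides: number of storage elements to skip per step along each dimension [s0, s1, …, sn]
- offset: starting position in storage (for views/slices)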
pub type Tensor {
  Tensor(data: List(Float), shape: List(Int))
  StridedTensor(
    storage: ErlangArray,
    shape: List(Int),
    strides: List(Int),
    offset: Int,
  )
}
Constructors
- Tensor(data: List(Float), shape: List(Int))
- StridedTensor(storage: ErlangArray, shape: List(Int), strides: List(Int), offset: Int)
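The stride arithmetic behind StridedTensor can be sketched in plain Gleam. This minimal sketch assumes strides are counted in storage elements and that a contiguous tensor is laid out row-major; contiguous_strides and storage_position are hypothetical helpers for illustration, not viva_tensor functions.

```gleam
import gleam/int
import gleam/list

// Hypothetical helper (not part of viva_tensor), assuming row-major layout
// and strides measured in elements: the last dimension has stride 1 and each
// earlier stride is the product of all later dimension sizes.
pub fn contiguous_strides(shape: List(Int)) -> List(Int) {
  let #(strides, _) =
    list.fold(list.reverse(shape), #([], 1), fn(acc, dim) {
      let #(strides, running) = acc
      #([running, ..strides], running * dim)
    })
  strides
}

// Hypothetical helper: storage position of a multi-dimensional index,
// computed as offset + sum(index_i * stride_i).
pub fn storage_position(offset: Int, strides: List(Int), index: List(Int)) -> Int {
  offset + int.sum(list.map2(index, strides, fn(i, s) { i * s }))
}
```

For a [2, 3, 4] tensor this gives strides [12, 4, 1], so index [1, 2, 3] lands at storage position 23. A view only changes shape, strides, or offset; the storage is shared.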
Tensor operation errors
pub type TensorError {
  ShapeMismatch(expected: List(Int), got: List(Int))
  InvalidShape(reason: String)
  DimensionError(reason: String)
  BroadcastError(a: List(Int), b: List(Int))
}
Constructors
- ShapeMismatch(expected: List(Int), got: List(Int))
- InvalidShape(reason: String)
- DimensionError(reason: String)
- BroadcastError(a: List(Int), b: List(Int))
Values
pub fn add_broadcast(
a: Tensor,
b: Tensor,
) -> Result(Tensor, TensorError)
Element-wise addition with broadcasting
pub fn avg_pool2d(
input: Tensor,
pool_h: Int,
pool_w: Int,
stride_h: Int,
stride_w: Int,
) -> Result(Tensor, TensorError)
Average pooling 2D, optimized with O(1) array access
pub fn broadcast_shape(
a: List(Int),
b: List(Int),
) -> Result(List(Int), TensorError)
Compute broadcast shape
pub fn broadcast_to(
t: Tensor,
target_shape: List(Int),
) -> Result(Tensor, TensorError)
Broadcast tensor to target shape
pub fn can_broadcast(a: List(Int), b: List(Int)) -> Bool
Check if two shapes can be broadcast together
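The broadcasting functions above presumably follow NumPy's rule: align the shapes from the trailing dimension, and require each aligned pair to be equal or to contain a 1 (a missing leading dimension counts as 1). A minimal sketch of that rule, assuming NumPy semantics; broadcast_dims is a hypothetical helper that returns Result(List(Int), Nil) instead of the library's TensorError.

```gleam
import gleam/list

// Hypothetical helper (not part of viva_tensor), assuming NumPy-style rules.
pub fn broadcast_dims(a: List(Int), b: List(Int)) -> Result(List(Int), Nil) {
  // Walk both shapes from the trailing dimension; prepending to the
  // accumulator restores the original (leading-first) order.
  do_broadcast(list.reverse(a), list.reverse(b), [])
}

fn do_broadcast(a: List(Int), b: List(Int), acc: List(Int)) -> Result(List(Int), Nil) {
  case a, b {
    [], [] -> Ok(acc)
    // A missing leading dimension behaves like size 1.
    [x, ..rest], [] -> do_broadcast(rest, [], [x, ..acc])
    [], [y, ..rest] -> do_broadcast([], rest, [y, ..acc])
    // Equal sizes, or b is 1: the result takes a's size.
    [x, ..ra], [y, ..rb] if x == y || y == 1 -> do_broadcast(ra, rb, [x, ..acc])
    // a is 1: the result takes b's size.
    [1, ..ra], [y, ..rb] -> do_broadcast(ra, rb, [y, ..acc])
    _, _ -> Error(Nil)
  }
}
```

Under this rule, can_broadcast(a, b) would correspond to result.is_ok(broadcast_dims(a, b)).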
pub fn concat_axis(
tensors: List(Tensor),
axis: Int,
) -> Result(Tensor, TensorError)
Concatenate tensors along a specific axis. For two [2, 3] tensors a and b: concat_axis([a, b], 0) -> [4, 3] and concat_axis([a, b], 1) -> [2, 6].
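The shape rule behind those examples: every dimension except the chosen axis must match, and the sizes along that axis add up. A minimal sketch working on shapes only; concat_shape is a hypothetical helper, not the library's implementation, and its Error(Nil) stands in for whatever TensorError the library returns.

```gleam
import gleam/list

// Hypothetical helper (not part of viva_tensor): output shape of
// concatenating tensors with the given shapes along `axis`.
pub fn concat_shape(shapes: List(List(Int)), axis: Int) -> Result(List(Int), Nil) {
  case shapes {
    [] -> Error(Nil)
    [first, ..rest] -> {
      // The shape with the concatenation axis removed must agree for all inputs.
      let without_axis = fn(shape: List(Int)) {
        list.append(list.take(shape, axis), list.drop(shape, axis + 1))
      }
      let compatible =
        axis >= 0
        && axis < list.length(first)
        && list.all(rest, fn(s) { without_axis(s) == without_axis(first) })
      case compatible {
        False -> Error(Nil)
        True -> {
          // Sum the sizes along the concatenation axis.
          let total =
            list.fold(shapes, 0, fn(acc, s) {
              case list.drop(s, axis) {
                [d, ..] -> acc + d
                [] -> acc
              }
            })
          Ok(list.flatten([list.take(first, axis), [total], list.drop(first, axis + 1)]))
        }
      }
    }
  }
}
```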
pub fn conv2d(
input: Tensor,
kernel: Tensor,
config: Conv2dConfig,
) -> Result(Tensor, TensorError)
2D convolution using optimized O(1) array access. Input: [H, W], [C, H, W], or [N, C, H, W]. Kernel: [K_out, K_in, KH, KW], or [KH, KW] for a single channel. Output: [H_out, W_out] or [N, K_out, H_out, W_out].
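H_out and W_out are not spelled out here. Assuming the library follows the usual convolution arithmetic, each spatial dimension is out = (in + 2 * padding - kernel) / stride + 1 with floor division. A minimal sketch; conv_output_size is a hypothetical helper, not a viva_tensor function.

```gleam
// Hypothetical helper (not part of viva_tensor), assuming standard
// convolution arithmetic with floor division. Returns Error(Nil) when
// the kernel does not fit into the padded input or the stride is not positive.
pub fn conv_output_size(size: Int, kernel: Int, stride: Int, padding: Int) -> Result(Int, Nil) {
  let padded = size + 2 * padding
  case stride > 0 && kernel > 0 && padded >= kernel {
    True -> Ok({ padded - kernel } / stride + 1)
    False -> Error(Nil)
  }
}
```

Under this formula, the default conv2d_config (3x3 kernel, stride 1, no padding) shrinks each spatial dimension by 2, while padding (k - 1) / 2 with an odd kernel k at stride 1 keeps the input size, which is what "same" padding means.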
pub fn conv2d_config() -> Conv2dConfig
Default conv2d config (3x3 kernel, stride 1, no padding)
pub fn conv2d_same(kernel_h: Int, kernel_w: Int) -> Conv2dConfig
Conv2d config with “same” padding (output same size as input)
pub fn from_list2d(
rows: List(List(Float)),
) -> Result(Tensor, TensorError)
Create 2D tensor (matrix) from list of lists
pub fn get(t: Tensor, index: Int) -> Result(Float, TensorError)
Access element by linear index
pub fn get2d(
t: Tensor,
row: Int,
col: Int,
) -> Result(Float, TensorError)
Access 2D element
pub fn get2d_fast(
t: Tensor,
row: Int,
col: Int,
) -> Result(Float, TensorError)
Get 2D element with O(1) access
pub fn get_col(
t: Tensor,
col_idx: Int,
) -> Result(Tensor, TensorError)
Get matrix column as vector
pub fn get_fast(
t: Tensor,
index: Int,
) -> Result(Float, TensorError)
Get element with O(1) access for StridedTensor
pub fn get_row(
t: Tensor,
row_idx: Int,
) -> Result(Tensor, TensorError)
Get matrix row as vector
pub fn global_avg_pool2d(
input: Tensor,
) -> Result(Tensor, TensorError)
Global average pooling: reduces the spatial dimensions to 1x1. Input: [N, C, H, W] -> Output: [N, C, 1, 1]
pub fn matmul(
a: Tensor,
b: Tensor,
) -> Result(Tensor, TensorError)
Matrix-matrix multiplication: [m, n] @ [n, p] -> [m, p]
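The [m, n] @ [n, p] -> [m, p] contract means each output entry is the dot product of a row of a with a column of b. A minimal sketch on row-major nested lists rather than the library's Tensor; matmul_lists is a hypothetical helper and does not validate shapes.

```gleam
import gleam/float
import gleam/list

// Hypothetical helper (not part of viva_tensor): naive matrix product of
// row-major nested lists, [m, n] @ [n, p] -> [m, p]. Shapes are not checked.
pub fn matmul_lists(a: List(List(Float)), b: List(List(Float))) -> List(List(Float)) {
  // Transposing b turns its columns into lists we can zip against rows of a.
  let b_cols = list.transpose(b)
  list.map(a, fn(row) { list.map(b_cols, fn(col) { dot(row, col) }) })
}

fn dot(x: List(Float), y: List(Float)) -> Float {
  float.sum(list.map2(x, y, fn(p, q) { p *. q }))
}
```

matmul_vec follows the same pattern with a single column: [m, n] @ [n] -> [m] is one dot product per row.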
pub fn matmul_vec(
mat: Tensor,
vec: Tensor,
) -> Result(Tensor, TensorError)
Matrix-vector multiplication: [m, n] @ [n] -> [m]
pub fn matrix(
rows: Int,
cols: Int,
data: List(Float),
) -> Result(Tensor, TensorError)
Create matrix (2D tensor) with explicit dimensions
pub fn max_pool2d(
input: Tensor,
pool_h: Int,
pool_w: Int,
stride_h: Int,
stride_w: Int,
) -> Result(Tensor, TensorError)
Max pooling 2D, optimized with O(1) array access. Input: [H, W] or [N, C, H, W]. Output: [H_out, W_out] or [N, C, H_out, W_out].
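For the [H, W] case, max pooling slides a pool_h x pool_w window with the given strides and keeps each window's maximum. A minimal sketch on nested lists, assuming "valid" pooling (incomplete windows at the edges are dropped, so H_out = (H - pool_h) / stride_h + 1); windows and max_pool_lists are hypothetical helpers, not the library's code.

```gleam
import gleam/float
import gleam/list

// Hypothetical helper: windows of `size` elements taken every `stride`
// elements; incomplete trailing windows are dropped.
fn windows(xs: List(a), size: Int, stride: Int) -> List(List(a)) {
  case size > 0 && stride > 0 && list.length(xs) >= size {
    True -> [list.take(xs, size), ..windows(list.drop(xs, stride), size, stride)]
    False -> []
  }
}

// Hypothetical helper (not part of viva_tensor): 2D max pooling of a
// row-major [H, W] nested list.
pub fn max_pool_lists(
  input: List(List(Float)),
  pool_h: Int,
  pool_w: Int,
  stride_h: Int,
  stride_w: Int,
) -> List(List(Float)) {
  windows(input, pool_h, stride_h)
  |> list.map(fn(row_group) {
    // Window each row: [rows][windows][values]; transposing regroups
    // the data as [windows][rows][values].
    row_group
    |> list.map(fn(row) { windows(row, pool_w, stride_w) })
    |> list.transpose
    |> list.map(fn(window) {
      case list.flatten(window) {
        [first, ..rest] -> list.fold(rest, first, float.max)
        [] -> 0.0
      }
    })
  })
}
```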
pub fn mean_axis(
t: Tensor,
axis_idx: Int,
) -> Result(Tensor, TensorError)
Mean along a specific axis
pub fn mul(a: Tensor, b: Tensor) -> Result(Tensor, TensorError)
Element-wise multiplication (Hadamard)
pub fn mul_broadcast(
a: Tensor,
b: Tensor,
) -> Result(Tensor, TensorError)
Element-wise multiplication with broadcasting
pub fn outer(a: Tensor, b: Tensor) -> Result(Tensor, TensorError)
Outer product of vectors: [m] and [n] -> [m, n]
pub fn pad2d(
t: Tensor,
pad_h: Int,
pad_w: Int,
) -> Result(Tensor, TensorError)
Pad a 2D tensor with zeros. Input: [H, W], Output: [H + 2 * pad_h, W + 2 * pad_w]
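Zero padding adds pad_w zeros on both sides of every row and pad_h all-zero rows above and below, which produces the [H + 2 * pad_h, W + 2 * pad_w] shape. A minimal sketch on nested lists; pad_lists is a hypothetical helper, not the library's implementation.

```gleam
import gleam/list

// Hypothetical helper (not part of viva_tensor): zero-pad a row-major
// [H, W] nested list to [H + 2 * pad_h, W + 2 * pad_w].
pub fn pad_lists(rows: List(List(Float)), pad_h: Int, pad_w: Int) -> List(List(Float)) {
  let side = list.repeat(0.0, pad_w)
  // Pad each row left and right.
  let padded_rows = list.map(rows, fn(row) { list.flatten([side, row, side]) })
  let width = case padded_rows {
    [first, ..] -> list.length(first)
    [] -> 2 * pad_w
  }
  // Add pad_h full-width zero rows above and below.
  let border = list.repeat(list.repeat(0.0, width), pad_h)
  list.flatten([border, padded_rows, border])
}
```

pad4d applies the same padding to every [H, W] plane of an [N, C, H, W] batch.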
pub fn pad4d(
t: Tensor,
pad_h: Int,
pad_w: Int,
) -> Result(Tensor, TensorError)
Pad a 4D tensor (batch) with zeros. Input: [N, C, H, W], Output: [N, C, H + 2 * pad_h, W + 2 * pad_w]
pub fn random_normal(
shape: List(Int),
mean_val: Float,
std_val: Float,
) -> Tensor
Tensor of normally distributed random values with mean mean_val and standard deviation std_val (approximated via the Box-Muller transform)
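The Box-Muller transform maps two uniform samples u1 in (0, 1] and u2 in [0, 1) to a standard normal sample z = sqrt(-2 ln u1) * cos(2 pi u2), which is then rescaled to mean + std * z. A minimal sketch of that transform, not the library's code; it targets Erlang only, because it calls Erlang's math module through @external, and box_muller and sample_normal are hypothetical names.

```gleam
import gleam/float

// Erlang-target FFI to the standard math module.
@external(erlang, "math", "log")
fn ln(x: Float) -> Float

@external(erlang, "math", "sqrt")
fn sqrt(x: Float) -> Float

@external(erlang, "math", "cos")
fn cos(x: Float) -> Float

@external(erlang, "math", "pi")
fn pi() -> Float

// Hypothetical helper (not part of viva_tensor): one Box-Muller sample,
// rescaled to the requested mean and standard deviation.
pub fn box_muller(u1: Float, u2: Float, mean: Float, std: Float) -> Float {
  // -2 ln(u1) is non-negative for u1 in (0, 1].
  let radius = sqrt(0.0 -. 2.0 *. ln(u1))
  let z = radius *. cos(2.0 *. pi() *. u2)
  mean +. std *. z
}

// float.random() is in [0, 1), so 1.0 - random() is in (0, 1] and ln stays finite.
pub fn sample_normal(mean: Float, std: Float) -> Float {
  box_muller(1.0 -. float.random(), float.random(), mean, std)
}
```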
pub fn random_uniform(shape: List(Int)) -> Tensor
Tensor with uniform random values [0, 1)
pub fn reshape(
t: Tensor,
new_shape: List(Int),
) -> Result(Tensor, TensorError)
Reshape tensor
pub fn slice(
t: Tensor,
start: List(Int),
lengths: List(Int),
) -> Result(Tensor, TensorError)
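Slice tensor: extract the sub-tensor that begins at start and spans lengths elements along each dimension (from start up to, but not including, start + lengths). For example, slice(t, [1], [3]) extracts the elements at indices 1, 2, and 3.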
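Along one dimension, a slice is "drop start, then take length"; in 2D the same operation is applied to the rows and then to each row. A minimal sketch on lists; slice_list and slice_2d are hypothetical helpers, not viva_tensor functions.

```gleam
import gleam/list

// Hypothetical helper (not part of viva_tensor): `length` elements starting
// at `start`, i.e. indices start .. start + length - 1.
pub fn slice_list(xs: List(a), start: Int, length: Int) -> List(a) {
  xs
  |> list.drop(start)
  |> list.take(length)
}

// Hypothetical helper: 2D slice of a row-major nested list.
pub fn slice_2d(
  rows: List(List(a)),
  row_start: Int,
  col_start: Int,
  n_rows: Int,
  n_cols: Int,
) -> List(List(a)) {
  rows
  |> slice_list(row_start, n_rows)
  |> list.map(slice_list(_, col_start, n_cols))
}
```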
pub fn squeeze_axis(
t: Tensor,
axis: Int,
) -> Result(Tensor, TensorError)
Remove the dimension at the given axis if its size is 1
pub fn stack(
tensors: List(Tensor),
axis: Int,
) -> Result(Tensor, TensorError)
Stack tensors along a new axis. For two [3] tensors a and b: stack([a, b], 0) -> [2, 3] and stack([a, b], 1) -> [3, 2].
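Unlike concat_axis, stacking inserts a new dimension whose size is the number of input tensors. A minimal sketch of that shape rule, which reproduces both examples above; stack_shape is a hypothetical helper that assumes all inputs share one shape.

```gleam
import gleam/list

// Hypothetical helper (not part of viva_tensor): output shape of stacking
// `count` tensors of identical `shape` along a new `axis`.
pub fn stack_shape(shape: List(Int), count: Int, axis: Int) -> Result(List(Int), Nil) {
  // The new axis may sit anywhere from before the first dimension to after the last.
  case axis >= 0 && axis <= list.length(shape) {
    True -> Ok(list.flatten([list.take(shape, axis), [count], list.drop(shape, axis)]))
    False -> Error(Nil)
  }
}
```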
pub fn sum_axis(
t: Tensor,
axis_idx: Int,
) -> Result(Tensor, TensorError)
Sum along a specific axis. For a [2, 3] tensor t, sum_axis(t, 0) gives a tensor of shape [3] and sum_axis(t, 1) gives a tensor of shape [2].
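Summing along axis 0 collapses the rows (one sum per column), and summing along axis 1 collapses the columns (one sum per row). A minimal sketch for the 2D case on nested lists; sum_axis_lists is a hypothetical helper, not the library's implementation.

```gleam
import gleam/float
import gleam/list

// Hypothetical helper (not part of viva_tensor): sum a row-major
// [rows, cols] nested list along axis 0 (result length cols) or
// axis 1 (result length rows).
pub fn sum_axis_lists(m: List(List(Float)), axis: Int) -> Result(List(Float), Nil) {
  case axis {
    // Transposing makes each column a list, so one sum per column.
    0 -> Ok(list.map(list.transpose(m), float.sum))
    1 -> Ok(list.map(m, float.sum))
    _ -> Error(Nil)
  }
}
```

mean_axis follows by dividing each sum by the size of the collapsed axis.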
pub fn to_contiguous(t: Tensor) -> Tensor
Convert a strided tensor back to a regular Tensor, materializing the view
pub fn to_list2d(
t: Tensor,
) -> Result(List(List(Float)), TensorError)
Convert matrix to list of lists
pub fn to_strided(t: Tensor) -> Tensor
Convert regular tensor to strided (O(n) once, then O(1) access)
pub fn transpose_strided(
t: Tensor,
) -> Result(Tensor, TensorError)
Zero-copy transpose: swaps the shape and strides without copying the underlying storage.
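Swapping shape and strides is enough because an element's storage position depends only on offset + sum(index_i * stride_i): element [i, j] of the transposed view reads the same storage slot as element [j, i] of the original. A minimal sketch with a hypothetical View record (not the library's StridedTensor), assuming strides in elements.

```gleam
import gleam/int
import gleam/list

// Hypothetical stand-in for a strided view (not part of viva_tensor).
pub type View {
  View(shape: List(Int), strides: List(Int), offset: Int)
}

// Zero-copy 2D transpose: only the index arithmetic changes.
pub fn transpose_view(v: View) -> Result(View, Nil) {
  case v.shape, v.strides {
    [r, c], [sr, sc] -> Ok(View(shape: [c, r], strides: [sc, sr], offset: v.offset))
    _, _ -> Error(Nil)
  }
}

// Storage position of a multi-dimensional index in the view.
pub fn position(v: View, index: List(Int)) -> Int {
  v.offset + int.sum(list.map2(index, v.strides, fn(i, s) { i * s }))
}
```

The trade-off is that the transposed view is no longer row-major contiguous, which is when a call like to_contiguous pays the O(n) copy.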
pub fn xavier_init(fan_in: Int, fan_out: Int) -> Tensor
Xavier initialization for weights
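Which Xavier (Glorot) variant xavier_init uses is not stated here. The two common scales are a uniform bound limit = sqrt(6 / (fan_in + fan_out)) and a normal standard deviation sqrt(2 / (fan_in + fan_out)). A minimal sketch of both; the helper names are hypothetical and not part of viva_tensor.

```gleam
import gleam/float
import gleam/int

// Hypothetical helper: Glorot/Xavier uniform bound, weights ~ U(-limit, limit).
pub fn xavier_uniform_limit(fan_in: Int, fan_out: Int) -> Result(Float, Nil) {
  case fan_in + fan_out > 0 {
    True -> float.square_root(6.0 /. int.to_float(fan_in + fan_out))
    False -> Error(Nil)
  }
}

// Hypothetical helper: Glorot/Xavier normal variant, weights ~ N(0, std^2).
pub fn xavier_normal_std(fan_in: Int, fan_out: Int) -> Result(Float, Nil) {
  case fan_in + fan_out > 0 {
    True -> float.square_root(2.0 /. int.to_float(fan_in + fan_out))
    False -> Error(Nil)
  }
}
```

Both variants give the weights a variance of 2 / (fan_in + fan_out), which keeps activation and gradient magnitudes roughly stable across layers.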