This document provides an overview of the Rust Tensor Library's architecture, explaining the key components, their relationships, and design decisions.
- High-Level Architecture
- Core Components
- Backend System
- Automatic Differentiation
- Hooks System
- Operations System
- Operator Overloading
- Optimizers
- Feature Flags
- Design Decisions
The RustTensor Library is designed around a modular architecture with several key components:
graph TD
User[User Code] --> Tensor[Tensor<B: Backend>]
Tensor --> Backend[Backend Trait]
Tensor --> Ops[Operations]
Tensor --> Graph[Computation Graph]
Backend --> CPU[CPU Backend]
Backend --> CUDA[CUDA Backend]
Ops --> CPU_Ops[CPU Implementations]
Ops --> CUDA_Ops[CUDA Implementations]
Graph --> Autograd[Automatic Differentiation]
Autograd --> Optim[Optimizers]
classDef core fill:#f9f,stroke:#333,stroke-width:2px;
classDef backend fill:#bbf,stroke:#333,stroke-width:1px;
classDef ops fill:#bfb,stroke:#333,stroke-width:1px;
classDef user fill:#fbb,stroke:#333,stroke-width:1px;
class Tensor,Graph,Autograd core;
class Backend,CPU,CUDA backend;
class Ops,CPU_Ops,CUDA_Ops ops;
class User,Optim user;
The Tensor<B: Backend> struct is the central type in the library:
- Generic over Backend: Uses a type parameter `B` that implements the `Backend` trait
- Data Storage: Contains data in backend-specific storage format
- Gradient Tracking: Optional gradient storage for automatic differentiation
- Computation Graph: Tracks operations for backward pass
- Device Information: Tracks where the tensor data resides (CPU/CUDA)
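use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;
// Handle to shared tensor state; several tensors can point at the same TensorData.
pub struct Tensor<B: Backend> {
    data: Rc<RefCell<TensorData<B>>>,
    _marker: PhantomData<B>,
}
pub struct TensorData<B: Backend> {
    pub id: usize,
    pub data: B::Storage,         // backend-specific storage
    pub grad: Option<B::Storage>, // optional gradient storage
    pub requires_grad: bool,
    pub op: Option<Op<B>>,        // operation tracked for the backward pass
    pub device: Device,           // where the data resides (CPU/CUDA)
    pub hooks: Vec<Box<dyn Hook<B>>>,
}
For convenience, the library provides type aliases for common backend configurations: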
pub type CpuTensor = Tensor<CpuBackend>;
pub type CudaTensor = Tensor<CudaBackend>;
The Backend trait defines the interface that all backends must implement:
pub trait Backend: 'static + Sized {
    type Storage: 'static;
    // Creation methods
    fn zeros(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn ones(shape: &[usize]) -> Result<Self::Storage, Error>;
    fn from_array(array: Array) -> Result<Self::Storage, Error>;
    // Data access methods
    fn shape(tensor: &Self::Storage) -> &[usize];
    fn copy_to_host(tensor: &Self::Storage) -> Result<Vec<f32>, Error>;
    // Operations
    fn add(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    fn sub(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    fn mul(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, Error>;
    // ... many more operations
}
The CpuBackend implements the Backend trait using ndarray for storage and computations:
- Storage Type: `Array` (a wrapper around `ndarray::Array`)
- Performance: Can use system BLAS libraries via the `cpu_openblas` feature
- Implementation: Uses vectorized operations where possible
The CudaBackend implements the Backend trait using CUDA for GPU acceleration:
- Storage Type: Custom CUDA memory management
- Kernels: Custom CUDA kernels for operations
- cuBLAS Integration: Uses cuBLAS for matrix operations
- Context Management: Thread-safe CUDA context handling
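
To make the backend pattern concrete, here is a minimal sketch rather than the library's actual code. It trims the trait to a handful of methods and implements it for a hypothetical `VecBackend` whose storage is a flat `Vec<f32>` plus a shape. `VecBackend`, `VecStorage`, `double`, and the `String` error type are assumptions made for illustration.

```rust
// Illustrative sketch only: a trimmed-down Backend trait and a toy backend.
// `VecBackend`, `VecStorage`, and the String error type are hypothetical.
pub trait Backend: 'static + Sized {
    type Storage: 'static;
    fn zeros(shape: &[usize]) -> Result<Self::Storage, String>;
    fn shape(storage: &Self::Storage) -> &[usize];
    fn add(a: &Self::Storage, b: &Self::Storage) -> Result<Self::Storage, String>;
    fn copy_to_host(storage: &Self::Storage) -> Result<Vec<f32>, String>;
}

pub struct VecStorage {
    data: Vec<f32>,
    shape: Vec<usize>,
}

pub struct VecBackend;

impl Backend for VecBackend {
    type Storage = VecStorage;

    fn zeros(shape: &[usize]) -> Result<VecStorage, String> {
        let len: usize = shape.iter().product();
        Ok(VecStorage { data: vec![0.0; len], shape: shape.to_vec() })
    }

    fn shape(storage: &VecStorage) -> &[usize] {
        &storage.shape
    }

    fn add(a: &VecStorage, b: &VecStorage) -> Result<VecStorage, String> {
        // This toy backend requires identical shapes (no broadcasting).
        if a.shape != b.shape {
            return Err(format!("shape mismatch: {:?} vs {:?}", a.shape, b.shape));
        }
        let data = a.data.iter().zip(&b.data).map(|(x, y)| x + y).collect();
        Ok(VecStorage { data, shape: a.shape.clone() })
    }

    fn copy_to_host(storage: &VecStorage) -> Result<Vec<f32>, String> {
        Ok(storage.data.clone())
    }
}

// Backend-agnostic code: compiles for any type that implements the trait.
pub fn double<B: Backend>(x: &B::Storage) -> Result<B::Storage, String> {
    B::add(x, x)
}
```

A function such as `double` never names a concrete backend, which is what allows the same user code to run on CPU or CUDA storage.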
The library uses dynamic automatic differentiation (autograd):
- Computation Graph: Built dynamically during forward pass
- Reverse-Mode Differentiation: Computes gradients efficiently
- Operation Tracking: Each operation records its inputs and a closure for backward pass
- Gradient Accumulation: Sums gradients when multiple paths lead to the same node
graph LR
A(Tensor A) --> C(Tensor C)
B(Tensor B) --> C
C --> D(Tensor D)
D -.backward.-> C
C -.backward.-> A
C -.backward.-> B
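
The points above can be illustrated with a small scalar sketch of reverse-mode differentiation. It is not the library's implementation: `Var`, `Node`, and `topo` are hypothetical names, values are plain `f32` scalars instead of tensors, and each node stores local derivatives rather than a backward closure.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical scalar node; the real library tracks tensors and an Op record.
struct Node {
    value: f32,
    grad: f32,
    // Each parent is paired with the local derivative d(self)/d(parent).
    parents: Vec<(Var, f32)>,
}

#[derive(Clone)]
pub struct Var(Rc<RefCell<Node>>);

impl Var {
    pub fn new(value: f32) -> Var {
        Var(Rc::new(RefCell::new(Node { value, grad: 0.0, parents: Vec::new() })))
    }
    pub fn value(&self) -> f32 {
        self.0.borrow().value
    }
    pub fn grad(&self) -> f32 {
        self.0.borrow().grad
    }
    // The graph is built dynamically: each forward op records its inputs.
    pub fn add(&self, other: &Var) -> Var {
        let out = Var::new(self.value() + other.value());
        out.0.borrow_mut().parents = vec![(self.clone(), 1.0), (other.clone(), 1.0)];
        out
    }
    pub fn mul(&self, other: &Var) -> Var {
        let out = Var::new(self.value() * other.value());
        out.0.borrow_mut().parents =
            vec![(self.clone(), other.value()), (other.clone(), self.value())];
        out
    }
    // Reverse-mode pass: visit nodes from the output back to the inputs.
    pub fn backward(&self) {
        let mut seen = Vec::new();
        let mut order = Vec::new();
        topo(self, &mut seen, &mut order);
        self.0.borrow_mut().grad = 1.0;
        for node in order.iter().rev() {
            let (grad, parents) = {
                let n = node.0.borrow();
                (n.grad, n.parents.clone())
            };
            for (parent, local) in parents {
                // `+=` accumulates when several paths reach the same node.
                parent.0.borrow_mut().grad += grad * local;
            }
        }
    }
}

// Depth-first post-order: every node appears after all of its parents.
fn topo(v: &Var, seen: &mut Vec<*const RefCell<Node>>, order: &mut Vec<Var>) {
    let ptr = Rc::as_ptr(&v.0);
    if seen.contains(&ptr) {
        return;
    }
    seen.push(ptr);
    let parents: Vec<Var> = v.0.borrow().parents.iter().map(|(p, _)| p.clone()).collect();
    for p in &parents {
        topo(p, seen, order);
    }
    order.push(v.clone());
}
```

For `z = x * y + x`, the input `x` is reached along two paths, so its gradient is the sum of both contributions (`y + 1`).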
The hooks system provides a mechanism to monitor and modify tensor operations during both forward and backward passes:
- Forward Hooks: Executed after a tensor operation is performed
- Backward Hooks: Executed when gradients flow through a tensor during the backward pass
- Hook Registration: Hooks are registered with tensors and return a unique ID for later removal
- Hook Interface: All hooks implement the `Hook` trait
pub trait Hook<B: Backend>: Debug {
    fn forward(&self, tensor: &Tensor<B>, input: &[&Tensor<B>], output: &Tensor<B>) -> Result<(), Error>;
    fn backward(&self, tensor: &Tensor<B>, grad_input: &[Option<Tensor<B>>], grad_output: &Tensor<B>) -> Result<(), Error>;
}
Hooks come in two kinds:
- FnHook: A convenience wrapper around closures for quick hook creation
- Custom Hook Types: Users can implement the `Hook` trait for custom behavior
Typical uses for hooks include:
- Debugging: Monitoring intermediate values during forward and backward passes
- Gradient Clipping: Modifying gradients to prevent exploding gradients
- Feature Visualization: Capturing activations in neural networks
- Custom Regularization: Applying custom regularization during the backward pass
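
As a rough illustration of hook registration and the gradient-clipping use case, the sketch below uses a deliberately simplified stand-in for the `Hook` trait: it sees only a mutable gradient slice and returns nothing. `GradHook`, `FnGradHook`, `HookRegistry`, and `clip_hook` are hypothetical names, not part of the library.

```rust
use std::fmt;

// Simplified stand-in for the Hook trait: it only sees a gradient slice.
pub trait GradHook: fmt::Debug {
    fn backward(&self, grad: &mut [f32]);
}

// Closure wrapper in the spirit of FnHook (hypothetical shape).
pub struct FnGradHook<F: Fn(&mut [f32])>(pub F);

impl<F: Fn(&mut [f32])> fmt::Debug for FnGradHook<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FnGradHook")
    }
}

impl<F: Fn(&mut [f32])> GradHook for FnGradHook<F> {
    fn backward(&self, grad: &mut [f32]) {
        (self.0)(grad)
    }
}

#[derive(Default)]
pub struct HookRegistry {
    next_id: usize,
    hooks: Vec<(usize, Box<dyn GradHook>)>,
}

impl HookRegistry {
    // Registers a hook and returns a unique ID for later removal.
    pub fn register(&mut self, hook: Box<dyn GradHook>) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.hooks.push((id, hook));
        id
    }

    // Returns true if a hook with this ID was present and removed.
    pub fn remove(&mut self, id: usize) -> bool {
        let before = self.hooks.len();
        self.hooks.retain(|(hook_id, _)| *hook_id != id);
        self.hooks.len() != before
    }

    // Runs every registered hook, in registration order, on the gradient.
    pub fn run_backward(&self, grad: &mut [f32]) {
        for (_, hook) in &self.hooks {
            hook.backward(grad);
        }
    }
}

// Element-wise clipping to [-limit, limit]; `limit` must be non-negative.
pub fn clip_hook(limit: f32) -> Box<dyn GradHook> {
    Box::new(FnGradHook(move |grad: &mut [f32]| {
        for g in grad.iter_mut() {
            *g = g.clamp(-limit, limit);
        }
    }))
}
```

Returning an ID from `register` keeps removal cheap and avoids comparing trait objects.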
The library provides operator overloading for tensors, allowing for more expressive and readable code:
- Arithmetic Operators: `+`, `-`, `*`, `/` for element-wise operations
- Comparison Operators: `==`, `!=` for tensor equality checks
Operator overloading is implemented for references to tensors to avoid unnecessary moves and clones:
use std::ops::{Add, Mul};
// Implementation for &Tensor + &Tensor
impl<'a, 'b, B: Backend> Add<&'b Tensor<B>> for &'a Tensor<B> {
    type Output = Result<Tensor<B>, Error>;
    fn add(self, other: &'b Tensor<B>) -> Self::Output {
        ops::add(self, other)
    }
}
// Implementation for &Tensor * f32 (scalar multiplication)
impl<'a, B: Backend> Mul<f32> for &'a Tensor<B> {
    type Output = Result<Tensor<B>, Error>;
    fn mul(self, scalar: f32) -> Self::Output {
        ops::mul_scalar(self, scalar)
    }
}
Rust's operator traits have no separate channel for reporting errors, so the library sets each operator's `Output` type to `Result<Tensor<B>, Error>`. The caller must handle that result explicitly, typically with the `?` operator:
// Using operator overloading with error handling
let c = (&a + &b)?; // Note the ? operator to handle potential errors
// Chaining operations
let result = (&a + &b)?.relu()?;
Operator overloading brings several benefits:
- Readability: Mathematical expressions look more natural
- Expressiveness: Complex operations can be written concisely
- Familiarity: Similar to Python's tensor libraries (PyTorch, NumPy)
A few implementation details are worth noting:
- All operators delegate to the corresponding functions in the `ops` module
- Operators are implemented for references to avoid unnecessary cloning
- Both tensor-tensor and tensor-scalar operations are supported
- Broadcasting is handled automatically by the underlying operations
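
The operator pattern can be tried in isolation with a self-contained sketch. `Tensor1` is a hypothetical one-dimensional stand-in for `Tensor<B>`, and the `Error` enum with its `ShapeMismatch` variant is assumed for illustration. Unlike the library, this toy version requires equal lengths instead of broadcasting.

```rust
use std::ops::{Add, Mul};

// Hypothetical 1-D stand-in for Tensor<B>.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor1 {
    pub data: Vec<f32>,
}

// Hypothetical error type with a single variant.
#[derive(Debug, PartialEq)]
pub enum Error {
    ShapeMismatch { left: usize, right: usize },
}

// &Tensor1 + &Tensor1: implemented on references, so no operand is moved.
impl<'a, 'b> Add<&'b Tensor1> for &'a Tensor1 {
    type Output = Result<Tensor1, Error>;
    fn add(self, other: &'b Tensor1) -> Self::Output {
        if self.data.len() != other.data.len() {
            return Err(Error::ShapeMismatch {
                left: self.data.len(),
                right: other.data.len(),
            });
        }
        let data = self.data.iter().zip(&other.data).map(|(x, y)| x + y).collect();
        Ok(Tensor1 { data })
    }
}

// &Tensor1 * f32: scalar multiplication, also returning a Result.
impl<'a> Mul<f32> for &'a Tensor1 {
    type Output = Result<Tensor1, Error>;
    fn mul(self, scalar: f32) -> Self::Output {
        Ok(Tensor1 { data: self.data.iter().map(|x| x * scalar).collect() })
    }
}

// Chaining: `?` unwraps the intermediate Result before the next operator.
pub fn scaled_sum(a: &Tensor1, b: &Tensor1, k: f32) -> Result<Tensor1, Error> {
    let sum = (a + b)?;
    &sum * k
}
```

The parentheses around `(a + b)` are required because `?` binds more tightly than `+`.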
Operations are implemented in a backend-agnostic way:
- Public API: Functions in the `ops` module that work with any backend
- Backend-Specific Implementation: Each operation delegates to the backend
- Gradient Registration: Operations register their backward pass during forward computation
- Broadcasting: Automatic shape broadcasting for compatible operations
Operations fall into the following categories:
- Element-wise Operations: Add, Subtract, Multiply, Divide, etc.
- Matrix Operations: MatMul, Transpose
- Activation Functions: ReLU, Sigmoid, Tanh, etc.
- Reduction Operations: Sum, Mean, Max, etc.
- Shape Operations: View, Expand, Squeeze, etc.
- Neural Network Operations: Conv2D, MaxPool2D, etc.
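
This document does not spell out the broadcasting rules. The sketch below assumes the familiar NumPy-style rule: shapes are aligned from the trailing dimension, and two dimensions are compatible when they are equal or one of them is 1. `broadcast_shape` is a hypothetical helper, not a library function.

```rust
// Computes the broadcast result shape under the assumed NumPy-style rule.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Walk from the trailing dimension; missing leading dims count as 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return Err(format!("cannot broadcast {:?} with {:?}", a, b)),
        };
    }
    Ok(out)
}
```

Under this rule, shapes `[2, 1, 3]` and `[4, 3]` combine to `[2, 4, 3]`, while `[2, 3]` and `[4]` are rejected.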
The library provides several optimizers for neural network training:
- SGD: Basic stochastic gradient descent
- MomentumSGD: SGD with momentum
- Adam: Adaptive moment estimation
- Adagrad: Adaptive gradient algorithm
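
For orientation, the first two update rules can be sketched over plain `f32` slices. This shows one common formulation only; the library's optimizers operate on tensors, and their exact update rules and hyperparameter names are not shown in this document. `sgd_step` and `momentum_step` are hypothetical helpers.

```rust
// Plain SGD: p <- p - lr * g
pub fn sgd_step(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

// SGD with momentum (one common form): v <- mu * v + g, then p <- p - lr * v
pub fn momentum_step(
    params: &mut [f32],
    grads: &[f32],
    velocity: &mut [f32],
    lr: f32,
    mu: f32,
) {
    for ((p, g), v) in params.iter_mut().zip(grads).zip(velocity.iter_mut()) {
        *v = mu * *v + g;
        *p -= lr * *v;
    }
}
```

The velocity buffer is per-parameter state that persists between steps, which is why optimizers are stateful objects rather than free functions.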
All optimizers follow a common interface:
pub trait Optimizer<B: Backend> {
    fn step(&mut self) -> Result<(), Error>;
    fn zero_grad(&mut self) -> Result<(), Error>;
}
The library uses feature flags to control optional functionality:
- `cuda`: Enables CUDA GPU support
- `serialization`: Enables tensor serialization
- `mnist`: Enables MNIST dataset loading utilities
- `debug_logs`: Enables detailed diagnostic logging
- `cpu_openblas`: Enables OpenBLAS acceleration for CPU operations
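
Feature flags like these are normally switched on at build time (for example `cargo build --features cuda`) and tested in code with `cfg` attributes. The sketch below shows the general mechanism only; `default_backend_name` is a hypothetical helper, not part of the library.

```rust
// Compiled only when the crate is built with the `cuda` feature.
#[cfg(feature = "cuda")]
pub fn default_backend_name() -> &'static str {
    "cuda"
}

// Fallback compiled when the `cuda` feature is off.
#[cfg(not(feature = "cuda"))]
pub fn default_backend_name() -> &'static str {
    "cpu"
}
```

Because exactly one of the two definitions is compiled, code behind a disabled feature adds nothing to the binary and does not need its dependencies installed.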
The library uses a trait-based approach for backends, allowing:
- Code reuse between backends
- Easy addition of new backends
- Backend-agnostic user code
Unlike static frameworks (e.g., TensorFlow 1.x), the library builds the computation graph dynamically:
- More flexible for research and experimentation
- Easier debugging
- More Pythonic/intuitive API
The library uses Rc<RefCell<TensorData<B>>> for tensor data:
- Allows multiple tensors to share the same data
- Enables in-place operations when possible
- Supports efficient view operations without copying data
Safety is handled as follows:
- Internal mutation is controlled via `RefCell`, which enforces Rust's borrowing rules at runtime
- Graph consistency is maintained via careful operation tracking
- Mutable operations are marked with warning comments
- Error handling uses `Result` throughout the codebase
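
The interplay of shared ownership and runtime borrow checking can be seen in a small sketch. `Handle` is a hypothetical stand-in for a tensor handle that stores a bare `Vec<f32>`, and turning a borrow conflict into a `String` error is an assumption about how such a failure might be surfaced.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical handle: clones share one storage, as with Rc<RefCell<TensorData>>.
#[derive(Clone)]
pub struct Handle {
    storage: Rc<RefCell<Vec<f32>>>,
}

impl Handle {
    pub fn new(data: Vec<f32>) -> Handle {
        Handle { storage: Rc::new(RefCell::new(data)) }
    }

    // Runs `f` while holding a shared (read) borrow of the data.
    pub fn with_data<R>(&self, f: impl FnOnce(&[f32]) -> R) -> R {
        let guard = self.storage.borrow();
        f(guard.as_slice())
    }

    // In-place fill; reports a conflict instead of panicking when the data
    // is already borrowed elsewhere.
    pub fn try_fill(&self, value: f32) -> Result<(), String> {
        let mut guard = self
            .storage
            .try_borrow_mut()
            .map_err(|_| String::from("tensor data is already borrowed"))?;
        guard.iter_mut().for_each(|x| *x = value);
        Ok(())
    }

    pub fn to_vec(&self) -> Vec<f32> {
        self.storage.borrow().to_vec()
    }
}
```

Cloning a `Handle` copies only the pointer, so a successful `try_fill` through one clone is visible through every other clone.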
Performance is addressed through:
- Custom CUDA kernels for critical operations
- cuBLAS integration for matrix operations
- Optional OpenBLAS integration for CPU
- Efficient memory management and reuse
Memory is managed differently on each backend:
- CPU Backend: Uses Rust's memory management via `ndarray`
- CUDA Backend: Custom CUDA memory allocation and deallocation
- View Operations: Create new tensors that share underlying storage
- In-place Operations: Modify tensor data in-place when possible
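
A view that shares storage can be modeled as a window onto a reference-counted buffer. The sketch below is a one-dimensional simplification under that assumption. `View` and its methods are hypothetical, and real views would also carry shape and stride information.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical 1-D view: a window (offset, len) onto shared storage.
pub struct View {
    storage: Rc<RefCell<Vec<f32>>>,
    offset: usize,
    len: usize,
}

impl View {
    pub fn from_vec(data: Vec<f32>) -> View {
        let len = data.len();
        View { storage: Rc::new(RefCell::new(data)), offset: 0, len }
    }

    // Creates a sub-view without copying; returns None if out of range.
    pub fn slice(&self, start: usize, len: usize) -> Option<View> {
        let end = start.checked_add(len)?;
        if end > self.len {
            return None;
        }
        Some(View {
            storage: Rc::clone(&self.storage),
            offset: self.offset + start,
            len,
        })
    }

    // In-place update, visible through every view of the same storage.
    pub fn add_scalar_(&self, k: f32) {
        let mut data = self.storage.borrow_mut();
        for x in &mut data[self.offset..self.offset + self.len] {
            *x += k;
        }
    }

    // Copies the viewed window out into an owned Vec.
    pub fn to_vec(&self) -> Vec<f32> {
        self.storage.borrow()[self.offset..self.offset + self.len].to_vec()
    }
}
```

Only `to_vec` copies data; `slice` and `add_scalar_` work directly on the shared buffer.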
Errors are handled consistently:
- Result Type: Operations return `Result<T, Error>` for robust error handling
- Error Types: Specific error variants for different failure modes
- Error Propagation: Uses the `?` operator throughout the codebase
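
The three points above fit together as in the following sketch. The library's actual error variants are not listed in this document, so `TensorError`, `ShapeMismatch`, and `Empty` are hypothetical, as are the slice-based `add`, `mean`, and `mean_of_sum` helpers.

```rust
use std::fmt;

// Hypothetical error type: one variant per failure mode.
#[derive(Debug, PartialEq)]
pub enum TensorError {
    ShapeMismatch { expected: usize, actual: usize },
    Empty,
}

impl fmt::Display for TensorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TensorError::ShapeMismatch { expected, actual } => {
                write!(f, "shape mismatch: expected {} elements, got {}", expected, actual)
            }
            TensorError::Empty => write!(f, "operation requires a non-empty tensor"),
        }
    }
}

impl std::error::Error for TensorError {}

pub fn add(a: &[f32], b: &[f32]) -> Result<Vec<f32>, TensorError> {
    if a.len() != b.len() {
        return Err(TensorError::ShapeMismatch { expected: a.len(), actual: b.len() });
    }
    Ok(a.iter().zip(b).map(|(x, y)| x + y).collect())
}

pub fn mean(a: &[f32]) -> Result<f32, TensorError> {
    if a.is_empty() {
        return Err(TensorError::Empty);
    }
    Ok(a.iter().sum::<f32>() / a.len() as f32)
}

// `?` forwards the first failure to the caller unchanged.
pub fn mean_of_sum(a: &[f32], b: &[f32]) -> Result<f32, TensorError> {
    let sum = add(a, b)?;
    mean(&sum)
}
```

Distinct variants let callers match on the failure mode instead of parsing a message string.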
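Thread safety depends on the backend:
- CPU Backend: Relies on Rust's ownership system; because Tensor holds an Rc<RefCell<...>>, it is neither Send nor Sync, so the compiler rejects sharing a tensor handle across threads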
- CUDA Backend: Thread-safe via CUDA context management
- Note: The library is not designed for concurrent modification of the same tensor from multiple threads
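
Since a handle built on `Rc` cannot cross a thread boundary, one option for off-thread work is to copy the values into a host `Vec<f32>` (the type `copy_to_host` returns) and move that buffer instead. The sketch below assumes this approach; `sum_on_worker` is a hypothetical helper.

```rust
use std::thread;

// Moves an owned host buffer into a worker thread and returns its sum.
// Vec<f32> is Send, so the move compiles; an Rc-based handle would not.
pub fn sum_on_worker(host_data: Vec<f32>) -> f32 {
    thread::spawn(move || host_data.iter().sum::<f32>())
        .join()
        .expect("worker thread panicked")
}
```

The copy costs memory and time, but it gives the worker exclusive ownership of its data.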