Skip to content
Merged
41 changes: 31 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ members = [
"tensor-tools",
]
exclude = [
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"

Expand All @@ -42,14 +42,35 @@ candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
criterion = { version = "0.5.1", default-features = false }
cudarc = { version = "0.16.3", features = [
"std",
"cublas",
"cublaslt",
"curand",
"driver",
"nvrtc",
"f16",
"cuda-version-from-build-system",
"dynamic-linking",
], default-features = false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
half = { version = "2.5.0", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
float8 = { git = "https://github.com/zackangelo/float8", branch = "cudarc_0_16", features = [
"num-traits",
"rand_distr",
] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
image = { version = "0.25.2", default-features = false, features = [
"jpeg",
"png",
] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
Expand All @@ -75,7 +96,7 @@ ug-cuda = "0.4.0"
ug-metal = "0.4.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
metal = { version = "0.27.0", features = ["mps"] }

[profile.release-with-debug]
inherits = "release"
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
memmap2 = { workspace = true }
Expand All @@ -43,7 +44,7 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -139,6 +140,11 @@ impl Tensor {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
f.write_u8(v.to_bits())?
}
}
}
Ok(())
}
Expand Down
118 changes: 118 additions & 0 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

Expand All @@ -25,6 +26,7 @@ pub enum CpuStorage {
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
F8E4M3(Vec<F8E4M3>),
}

#[derive(Debug, Clone)]
Expand All @@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> {
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
F8E4M3(&'a [F8E4M3]),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -1691,6 +1694,17 @@ impl CpuStorage {
.concat();
Self::F64(storages)
}
Self::F8E4M3(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F8E4M3(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F8E4M3(storages)
}
};
Ok(s)
}
Expand All @@ -1708,6 +1722,7 @@ impl BackendStorage for CpuStorage {
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
Self::F8E4M3(_) => DType::F8E4M3,
}
}

Expand Down Expand Up @@ -1742,6 +1757,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::F8E4M3(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::U8(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -1770,6 +1789,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, f16::from_f64);
Ok(Self::F16(data))
}
(Self::F8E4M3(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::U8(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -1798,6 +1821,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F8E4M3(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::U8(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U8(data))
Expand Down Expand Up @@ -1826,6 +1853,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::F8E4M3(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand Down Expand Up @@ -1854,6 +1885,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F8E4M3(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -1882,6 +1917,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F8E4M3(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -1910,6 +1949,42 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
(Self::F8E4M3(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::U8(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::U32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::BF16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f32);
Ok(Self::F8E4M3(data))
}
(Self::F64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f64);
Ok(Self::F8E4M3(data))
}
(Self::F8E4M3(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F8E4M3(data))
}
}
}

Expand Down Expand Up @@ -2023,6 +2098,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand All @@ -2048,6 +2127,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand Down Expand Up @@ -2092,6 +2175,15 @@ impl BackendStorage for CpuStorage {
Ok(Self::F64(data))
}
}
Self::F8E4M3(storage) => {
if B::F8E4M3_VEC {
let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec);
Ok(Self::F8E4M3(data))
} else {
let data = unary_map(storage, layout, B::f8e4m3);
Ok(Self::F8E4M3(data))
}
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
Expand Down Expand Up @@ -2564,6 +2656,7 @@ impl BackendStorage for CpuStorage {
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
(Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
(st, s) => crate::bail!(
"const_set dtype mismatch, expected {:?} but got {:?}",
st.dtype(),
Expand Down Expand Up @@ -2632,6 +2725,16 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<F8E4M3, _>(uniform))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
Expand Down Expand Up @@ -2679,6 +2782,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal =
Expand Down Expand Up @@ -2742,6 +2854,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::F64(v)
}
DType::F8E4M3 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F8E4M3(v)
}
};
Ok(storage)
}
Expand All @@ -2754,6 +2871,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
};
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub trait Map1 {
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
}
}
}
Expand All @@ -31,6 +32,7 @@ pub trait Map1Any {
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
}
}
}
Expand All @@ -48,6 +50,7 @@ pub trait Map2 {
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down Expand Up @@ -95,6 +98,7 @@ pub trait Map2U8 {
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down
Loading