Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1303,8 +1303,42 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}

fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
crate::bail!("Metal upsample_nearest1d not implemented")
/// Nearest-neighbour upsampling of a rank-3 input on the Metal backend.
///
/// The first two dimensions of `inp_l` are kept and the last one is resized to
/// `out_w`, so the result holds `out_w * dims[0] * dims[1]` elements of the same
/// dtype as `self`.
///
/// # Errors
///
/// Fails when the layout is not rank 3, when the dtype is not one of
/// F32 / F16 / BF16 / U8 / U32, or when buffer allocation, encoder creation or
/// the kernel call reports an error.
fn upsample_nearest1d(&self, inp_l: &Layout, out_w: usize) -> Result<Self> {
    let dims = inp_l.shape().dims();
    if dims.len() != 3 {
        crate::bail!("unexpected input shape for upsample1d {dims:?}")
    }

    // Kernel entry points are specialised per element type.
    let kernel_name = match self.dtype {
        DType::F32 => "upsample_nearest1d_f32",
        DType::F16 => "upsample_nearest1d_f16",
        DType::BF16 => "upsample_nearest1d_bf16",
        DType::U8 => "upsample_nearest1d_u8",
        DType::U32 => "upsample_nearest1d_u32",
        dtype => crate::bail!("Metal upsample_nearest1d {dtype:?} not implemented"),
    };

    let out_elems = out_w * dims[0] * dims[1];
    let output = self
        .device
        .new_buffer(out_elems, self.dtype, "upsample_nearest1d")?;

    let encoder = self.device.command_encoder()?;
    encoder.set_label("upsample_nearest1d");

    candle_metal_kernels::call_upsample_nearest_1d(
        &self.device.device,
        &encoder,
        &self.device.kernels,
        kernel_name,
        dims,
        inp_l.stride(),
        out_w,
        buffer_o(&self.buffer, inp_l, self.dtype),
        &output,
    )
    .map_err(MetalError::from)?;

    Ok(Self::new(output, self.device.clone(), out_elems, self.dtype))
}

fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
Expand Down
14 changes: 14 additions & 0 deletions candle-core/tests/pool_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
Ok(())
}

/// Upsamples the row `[0, 1, 2]` to width 6 on `dev` and checks that every
/// source value appears twice in a row, leaving the input tensor unchanged.
fn upsample_nearest1d(dev: &Device) -> Result<()> {
    let input = Tensor::arange(0f32, 3f32, dev)?.reshape((1, 1, 3))?;
    let output_row = input.upsample_nearest1d(6)?.i(0)?.i(0)?;

    let input_values = input.i(0)?.i(0)?.to_vec1::<f32>()?;
    assert_eq!(input_values, [0.0, 1.0, 2.0]);

    let output_values = output_row.to_vec1::<f32>()?;
    assert_eq!(output_values, [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);
    Ok(())
}

fn upsample_nearest2d(dev: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
Expand Down Expand Up @@ -109,6 +117,12 @@ test_device!(
avg_pool2d_pytorch_metal
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
// Registers the `upsample_nearest1d` check above via `test_device!`. The three
// trailing identifiers presumably name the generated CPU / GPU / Metal test
// functions — the macro definition is not visible here, confirm against it.
test_device!(
upsample_nearest1d,
upsample_nearest1d_cpu,
upsample_nearest1d_gpu,
upsample_nearest1d_metal
);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
Expand Down
27 changes: 27 additions & 0 deletions candle-metal-kernels/src/kernels/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,33 @@ pub fn call_im2col_strided(
Ok(())
}

/// Encodes a dispatch of the 1-D nearest-neighbour upsample kernel `name`.
///
/// The pipeline is loaded from `Source::Conv`. Kernel arguments are bound in
/// this order: `out_w`, the width scale (`shape[2] / out_w` as `f32`), `shape`,
/// `strides`, `input`, `output`. The dispatch is sized for
/// `out_w * shape[0] * shape[1]` elements.
///
/// # Errors
///
/// Returns the error from `load_pipeline` if the pipeline cannot be loaded.
///
/// # Panics
///
/// Panics if `shape` has fewer than three entries, because the dimensions are
/// read by index.
// The argument list mirrors the kernel's parameter list, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_1d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_w: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;

    // Total number of output elements; used to size the dispatch.
    let dst_el = out_w * shape[0] * shape[1];
    // Ratio of input width to output width, used by the kernel to map indices.
    let scale_w = (shape[2] as f32) / (out_w as f32);
    let (group_count, group_size) = linear_split(&pipeline, dst_el);

    let provided = ep.encoder();
    let enc: &ComputeCommandEncoder = provided.as_ref();
    enc.set_compute_pipeline_state(&pipeline);
    set_params!(
        enc,
        (out_w, scale_w, shape, strides, &input, Output::new(output))
    );
    enc.dispatch_thread_groups(group_count, group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
Expand Down
47 changes: 47 additions & 0 deletions candle-metal-kernels/src/metal_src/conv.metal
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,32 @@ METAL_FUNC void im2col1d(
}
}

// Nearest-neighbour 1-D upsampling; each invocation writes one output element.
//
//   w_out    - output width
//   w_scale  - factor applied to an output column to obtain the source column
//   src_dims - source dimensions, read as [0], [1] (channels) and [2] (width)
//   src_s    - source strides, one per dimension, in elements
//   src/dst  - source and destination buffers; dst is indexed directly by tid
//   tid      - linear index of the output element handled by this invocation
template <typename T>
METAL_FUNC void upsample_nearest1d(
    constant size_t &w_out,
    constant float &w_scale,
    constant size_t *src_dims,
    constant size_t *src_s,
    device const T *src,
    device T *dst,
    uint tid
) {
    const size_t channels = src_dims[1];
    const size_t in_width = src_dims[2];

    // Invocations past the end of the output have nothing to write.
    if (tid >= src_dims[0] * channels * w_out) {
        return;
    }

    // Split the linear output index into (batch, channel, column).
    const size_t batch_idx = tid / (w_out * channels);
    const size_t channel_idx = (tid / w_out) % channels;
    const size_t out_col = tid % w_out;

    // Map the output column to a source column and clamp it to the last valid one.
    const size_t scaled_col = static_cast<size_t>(out_col * w_scale);
    const size_t in_col = (scaled_col >= in_width) ? (in_width - 1) : scaled_col;

    const size_t src_offset =
        batch_idx * src_s[0] + channel_idx * src_s[1] + in_col * src_s[2];
    dst[tid] = src[src_offset];
}

template <typename T>
METAL_FUNC void upsample_nearest2d(
constant size_t &w_out,
Expand Down Expand Up @@ -334,6 +360,19 @@ kernel void FN_NAME( \
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
} \

// Defines a `kernel` entry point named FN_NAME for element type TYPENAME. It
// forwards its arguments, together with the thread's position in the grid, to
// upsample_nearest1d<TYPENAME>.
#define UPSAMPLE_NEAREST1D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
constant float &w_scale, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
upsample_nearest1d<TYPENAME>(w_out, w_scale, dims, strides, src, dst, tid); \
} \

#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
Expand Down Expand Up @@ -670,6 +709,14 @@ IM2COL1D_OP(uint32_t, im2col1d_u32)
IM2COL1D_OP(bfloat, im2col1d_bf16)
#endif

// Instantiate the 1-D nearest-neighbour upsample kernels, one per element type.
// The bfloat variant is only compiled when __HAVE_BFLOAT__ is defined.
UPSAMPLE_NEAREST1D_OP(float, upsample_nearest1d_f32)
UPSAMPLE_NEAREST1D_OP(half, upsample_nearest1d_f16)
UPSAMPLE_NEAREST1D_OP(uint8_t, upsample_nearest1d_u8)
UPSAMPLE_NEAREST1D_OP(uint32_t, upsample_nearest1d_u32)
#if defined(__HAVE_BFLOAT__)
UPSAMPLE_NEAREST1D_OP(bfloat, upsample_nearest1d_bf16)
#endif

UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
Expand Down