|
1 | 1 | // SPDX-License-Identifier: MIT or Apache-2.0 |
2 | 2 | // First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil |
3 | 3 |
|
4 | | -use candle::{DType, Result, Tensor}; |
| 4 | +use candle::{Result, Tensor}; |
| 5 | + |
| 6 | +#[cfg(feature = "cuda")] |
| 7 | +use candle::DType; |
5 | 8 | #[cfg(feature = "cuda")] |
6 | 9 | use candle_index_select_cu; |
7 | 10 |
|
8 | 11 | #[inline] |
9 | 12 | #[allow(dead_code)] |
10 | 13 | pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result<Tensor> { |
11 | | - if cfg!(feature = "cuda") |
12 | | - && matches!(tensor.dtype(), DType::F16 | DType::F32) |
13 | | - && matches!(ids.dtype(), DType::U32) |
| 14 | + #[cfg(feature = "cuda")] |
| 15 | + { |
| 16 | + if matches!(tensor.dtype(), DType::F16 | DType::F32) && matches!(ids.dtype(), DType::U32) { |
| 17 | + // NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices |
| 18 | + candle_index_select_cu::index_select(tensor, ids, dim) |
| 19 | + } else { |
| 20 | + tensor.index_select(ids, dim) |
| 21 | + } |
| 22 | + } |
| 23 | + #[cfg(not(feature = "cuda"))] |
14 | 24 | { |
15 | | - // NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices |
16 | | - candle_index_select_cu::index_select(tensor, ids, dim) |
17 | | - } else { |
18 | 25 | tensor.index_select(ids, dim) |
19 | 26 | } |
20 | 27 | } |
0 commit comments