Skip to content

Commit c19087e

Browse files
committed
Fix index_select feature gating
1 parent f6880c7 commit c19087e

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
// SPDX-License-Identifier: MIT or Apache-2.0
22
// First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil
33

4-
use candle::{DType, Result, Tensor};
4+
use candle::{Result, Tensor};
5+
6+
#[cfg(feature = "cuda")]
7+
use candle::DType;
58
#[cfg(feature = "cuda")]
69
use candle_index_select_cu;
710

811
#[inline]
912
#[allow(dead_code)]
1013
pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result<Tensor> {
11-
if cfg!(feature = "cuda")
12-
&& matches!(tensor.dtype(), DType::F16 | DType::F32)
13-
&& matches!(ids.dtype(), DType::U32)
14+
#[cfg(feature = "cuda")]
15+
{
16+
if matches!(tensor.dtype(), DType::F16 | DType::F32) && matches!(ids.dtype(), DType::U32) {
17+
// NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
18+
candle_index_select_cu::index_select(tensor, ids, dim)
19+
} else {
20+
tensor.index_select(ids, dim)
21+
}
22+
}
23+
#[cfg(not(feature = "cuda"))]
1424
{
15-
// NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
16-
candle_index_select_cu::index_select(tensor, ids, dim)
17-
} else {
1825
tensor.index_select(ids, dim)
1926
}
2027
}

0 commit comments

Comments
 (0)