Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,15 @@ pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}

/// Applies the Scaled Exponential Linear Unit (SELU) activation element-wise.
///
/// For every element `x` of `xs` the result is
/// `gamma * x` when `x > 0` and `gamma * alpha * (exp(x) - 1)` otherwise.
///
/// # Arguments
/// * `xs` - input tensor; the output has the same shape.
/// * `alpha` - scale applied to the exponential (negative) branch.
/// * `gamma` - overall output scale.
///
/// # Errors
/// Propagates any error raised by the underlying tensor operations.
pub fn selu(xs: &Tensor, alpha: f32, gamma: f32) -> Result<Tensor> {
    let is_pos = xs.gt(0f32)?;
    // The scalars are applied through `affine` / scalar multiplication rather than
    // through full-size `f32` constant tensors: this keeps the computation in the
    // dtype of `xs` (instead of requiring `xs` to be `f32`) and avoids allocating
    // two extra tensors of the input's shape.
    let alpha = f64::from(alpha);
    // alpha * exp(x) - alpha == alpha * (exp(x) - 1)
    let neg = xs.exp()?.affine(alpha, -alpha)?;
    is_pos.where_cond(xs, &neg)? * f64::from(gamma)
}

pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
// This implementation is inefficient as it stores the full mask for the backward pass.
// Instead we could just store the seed and have a specialized kernel that would both
Expand Down
13 changes: 13 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,19 @@ fn simple_eval_(
let output = input.sign()?;
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__Selu.html
// y = gamma * x for x > 0, gamma * alpha * (exp(x) - 1) otherwise.
"Selu" => {
    let input = get(&node.input[0])?;
    // Both attributes are optional; the fallbacks below are used when the node
    // does not carry them (they match the defaults listed in the ONNX spec).
    let alpha = get_attr_opt::<f32>(node, "alpha")?
        .copied()
        .unwrap_or(1.6732632);
    let gamma = get_attr_opt::<f32>(node, "gamma")?
        .copied()
        .unwrap_or(1.050701);
    // `alpha` and `gamma` are already `f32`, so no cast is needed here.
    let out = candle_nn::ops::selu(input, alpha, gamma)?;
    values.insert(node.output[0].clone(), out);
}

// https://onnx.ai/onnx/operators/onnx__OneHot.html
"OneHot" => {
let indices = get(&node.input[0])?;
Expand Down
194 changes: 194 additions & 0 deletions candle-onnx/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6248,6 +6248,200 @@ fn test_sign_operation() -> Result<()> {
}

/// Checks the ONNX `Selu` operator as evaluated by `simple_eval`.
///
/// Covered cases: default `alpha`/`gamma`, custom attributes, large-magnitude
/// inputs, the attribute combination used in the ONNX operator documentation,
/// and an empty input tensor.
#[test]
fn test_selu_operator() -> Result<()> {
    // Builds a model with a single `Selu` node reading "input" and writing
    // "output". An empty attribute list leaves `alpha`/`gamma` unset on the node.
    let selu_model = |attribute: Vec<AttributeProto>| {
        create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Selu".to_string(),
                attribute,
                input: vec!["input".to_string()],
                output: vec!["output".to_string()],
                ..Default::default()
            }],
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            output: vec![ValueInfoProto {
                name: "output".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }))
    };

    // Creates a float-valued node attribute.
    let float_attr = |name: &str, f: f32| AttributeProto {
        name: name.to_string(),
        r#type: AttributeType::Float as i32,
        f,
        ..Default::default()
    };

    // Evaluates the `Selu` model on `input` and returns the "output" tensor.
    let run = |attribute: Vec<AttributeProto>, input: Tensor| -> Result<Tensor> {
        let model = selu_model(attribute);
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), input);
        let mut outputs = simple_eval(&model, inputs)?;
        Ok(outputs
            .remove("output")
            .expect("the Selu graph must produce an `output` value"))
    };

    // Test 1: Default alpha and gamma
    {
        let input = Tensor::from_vec(vec![-1.0f32, 0.0, 1.0, 2.0], (2, 2), &Device::Cpu)?;
        let output = run(vec![], input)?;
        let out_vec = to_vec2_round(&output, 4)?;
        assert_eq!(out_vec, vec![vec![-1.1113, 0.0], vec![1.0507, 2.1014]]);
    }

    // Test 2: Change alpha and gamma
    {
        let attrs = vec![float_attr("alpha", 2.0), float_attr("gamma", 0.5)];
        let input = Tensor::from_vec(vec![-1.0f32, 0.0, 1.0, 2.0], (2, 2), &Device::Cpu)?;
        let output = run(attrs, input)?;
        let out_vec = to_vec2_round(&output, 4)?;
        assert_eq!(out_vec, vec![vec![-0.6321, 0.0], vec![0.5, 1.0]]);
    }

    // Test 3: Different input values
    {
        let input = Tensor::from_vec(vec![-10.0f32, -5.0, 0.0, 10.0], (2, 2), &Device::Cpu)?;
        let output = run(vec![], input)?;
        let out_vec = to_vec2_round(&output, 4)?;
        assert_eq!(out_vec, vec![vec![-1.758, -1.7463], vec![0.0, 10.507]]);
    }

    // Test 4: Test based on https://github.com/onnx/onnx/blob/main/docs/Operators.md#Selu
    {
        let attrs = vec![float_attr("alpha", 2.0), float_attr("gamma", 3.0)];
        let input = Tensor::from_vec(vec![-1.0f32, 0.0, 1.0], (3,), &Device::Cpu)?;
        let output = run(attrs, input)?;
        let out_vec = output.to_vec1::<f32>()?;
        let expected = [-3.7927232f32, 0.0, 3.0];

        // `zip` stops at the shorter side, so check the length explicitly first.
        assert_eq!(out_vec.len(), expected.len());
        for (o, e) in out_vec.iter().zip(expected.iter()) {
            assert!((o - e).abs() < 1e-5, "Got {o}, expected {e}");
        }
    }

    // Test 5: Empty tensor
    {
        let input = Tensor::from_vec(Vec::<f32>::new(), (0, 2), &Device::Cpu)?;
        let output = run(vec![], input)?;
        assert_eq!(output.dims(), &[0, 2]);
    }

    Ok(())
}

fn test_hard_swish() -> candle::Result<()> {
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
Expand Down
Loading