Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions lib/axon/activations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,60 @@ defmodule Axon.Activations do
|> Nx.multiply(x)
end

@doc ~S"""
Swish-gated linear unit activation.

SwiGLU splits the input tensor along the given axis into two equal
halves $a$ and $b$, then returns $silu(a) \odot b$. The dimension of
the input along the given axis must be divisible by 2.

$$f(x) = silu(a) \odot b \quad \text{where} \quad x = [a, b]$$

## Options

* `:axis` - axis along which to split the input into the activation
and gate halves. Defaults to `-1`.

## Examples

iex> Axon.Activations.swiglu(Nx.tensor([[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]], names: [:batch, :data]))
#Nx.Tensor<
f32[batch: 1][data: 4]
[
[-0.14227762818336487, -0.4768116772174835, -0.8068243265151978, 0.0]
]
>

### Error cases

iex> Axon.Activations.swiglu(Nx.tensor([1.0, 2.0, 3.0]))
** (ArgumentError) axis -1 of input to swiglu must have a dimension divisible by 2, got dimension of size 3

## References
* [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)
"""
defn swiglu(x, opts \\ []) do
opts = keyword!(opts, axis: -1)
{a, b} = split_halves(x, opts[:axis])
silu(a) * b

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanmor5 how about adding our first Nx.block here?

end

deftransformp split_halves(x, axis) do
axis_idx = Nx.axis_index(x, axis)
size = elem(Nx.shape(x), axis_idx)

if rem(size, 2) != 0 do
raise ArgumentError,
"axis #{inspect(axis)} of input to swiglu must have a dimension" <>
" divisible by 2, got dimension of size #{size}"
end

half = div(size, 2)
a = Nx.slice_along_axis(x, 0, half, axis: axis_idx)
b = Nx.slice_along_axis(x, half, half, axis: axis_idx)
{a, b}
end

@doc ~S"""
Hyperbolic tangent activation.

Expand Down
2 changes: 1 addition & 1 deletion lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2243,7 +2243,7 @@ defmodule Axon.Layers do
# to use when invoking activation layers.
@activation_layers [:exp, :gelu, :hard_tanh, :linear, :log_sigmoid] ++
[:mish, :relu, :relu6, :sigmoid, :silu, :softplus] ++
[:softsign, :tanh]
[:softsign, :tanh, :swiglu]

for activation <- @activation_layers do
@doc false
Expand Down
Loading