mindspore.nn.GLU

class mindspore.nn.GLU(axis=-1)

The gated linear unit function.

\[\text{GLU}(a, b) = a \otimes \sigma(b)\]

where \(a\) is the first half of the input tensor and \(b\) is the second half, obtained by splitting the input evenly along axis.

Here \(\sigma\) is the sigmoid function, and \(\otimes\) is the Hadamard product.
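The formula can be cross-checked outside MindSpore. The following is a minimal NumPy sketch, not the MindSpore implementation: the name glu_reference is hypothetical, and rejecting an odd-sized axis is an assumption made here so that the two halves have equal shape.

```python
import numpy as np

def glu_reference(x, axis=-1):
    """NumPy sketch of GLU(a, b) = a * sigmoid(b) (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float64)
    # Assumption: the split axis must have even size so both halves match.
    if x.shape[axis] % 2 != 0:
        raise ValueError("size along the split axis must be even")
    # a is the first half along `axis`, b is the second half.
    a, b = np.split(x, 2, axis=axis)
    # Element-wise (Hadamard) product of a with sigmoid(b).
    return a * (1.0 / (1.0 + np.exp(-b)))

x = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
out = glu_reference(x)
print(out.shape)  # (2, 2)
```

For this input, the sketch agrees with the values printed in the Examples section to within floating-point tolerance.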

Parameters

axis (int) – The axis along which to split the input. Default: -1, the last axis of x.

Inputs:
  • x (Tensor) - Tensor of shape \((\ast_1, N, \ast_2)\), where \(\ast\) means any number of additional dimensions and \(N\) is the size along axis.

Outputs:

Tensor, with the same dtype as x and the shape \((\ast_1, M, \ast_2)\), where \(M=N/2\).
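The shape rule above can be sketched in plain Python. The helper glu_output_shape below is hypothetical and is not part of MindSpore; it only halves the size along axis, and it assumes an odd size is an error because \(M=N/2\) would not be an integer.

```python
def glu_output_shape(shape, axis=-1):
    """Return the output shape implied by M = N / 2 (hypothetical helper)."""
    shape = list(shape)
    n = shape[axis]
    # Assumption: N must be even for the input to split into equal halves.
    if n % 2 != 0:
        raise ValueError(f"size {n} along axis {axis} is not even")
    shape[axis] = n // 2
    return tuple(shape)

print(glu_output_shape((2, 4)))              # (2, 2)
print(glu_output_shape((8, 6, 10), axis=1))  # (8, 3, 10)
print(glu_output_shape((8, 6, 10), axis=0))  # (4, 6, 10)
```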

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> m = ms.nn.GLU()
>>> input = ms.Tensor([[0.1,0.2,0.3,0.4],[0.5,0.6,0.7,0.8]])
>>> output = m(input)
>>> print(output)
[[0.05744425 0.11973753]
 [0.33409387 0.41398472]]