mindspore.nn.RNNCell

class mindspore.nn.RNNCell(input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = 'tanh', dtype=mstype.float32)

An Elman RNN cell with tanh or ReLU non-linearity.

\[h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})\]

Here \(h_t\) is the hidden state at time \(t\), \(x_t\) is the input at time \(t\), and \(h_{(t-1)}\) is the hidden state at the previous time step \(t-1\), or the initial hidden state at time 0. If nonlinearity is "relu", ReLU is applied instead of tanh.
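The update rule above can be sketched in plain Python for a single sample. This is a minimal reference sketch, not MindSpore's implementation: the function name `rnn_cell_step`, the list-of-lists weight layout (`w_ih` of shape \((hidden\_size, input\_size)\), `w_hh` of shape \((hidden\_size, hidden\_size)\)) and the use of Python floats instead of float32 tensors are all assumptions made for illustration.

```python
import math
from typing import List, Optional, Sequence

Matrix = Sequence[Sequence[float]]
Vector = Sequence[float]


def _matvec(w: Matrix, v: Vector) -> List[float]:
    # Row-by-row dot product: (rows, cols) @ (cols,) -> (rows,)
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def rnn_cell_step(
    x: Vector,
    h_prev: Vector,
    w_ih: Matrix,
    w_hh: Matrix,
    b_ih: Optional[Vector] = None,
    b_hh: Optional[Vector] = None,
    nonlinearity: str = "tanh",
) -> List[float]:
    """Hypothetical reference for h_t = act(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh)."""
    hidden_size = len(w_ih)
    # Omitted biases (the has_bias=False case) are treated as zeros in this sketch.
    b_ih = b_ih if b_ih is not None else [0.0] * hidden_size
    b_hh = b_hh if b_hh is not None else [0.0] * hidden_size
    pre = [
        a + bi + c + bh
        for a, bi, c, bh in zip(_matvec(w_ih, x), b_ih, _matvec(w_hh, h_prev), b_hh)
    ]
    if nonlinearity == "tanh":
        return [math.tanh(z) for z in pre]
    if nonlinearity == "relu":
        return [max(0.0, z) for z in pre]
    raise ValueError("nonlinearity must be 'tanh' or 'relu'")
```

Calling it once per time step and feeding each result back in as `h_prev` mirrors the loop in the Examples section.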

Parameters
  • input_size (int) – Number of features in the input \(x_t\).

  • hidden_size (int) – Number of features in the hidden state \(h_t\).

  • has_bias (bool) – Whether the cell has the biases \(b_{ih}\) and \(b_{hh}\). Default: True.

  • nonlinearity (str) – The non-linearity to use, either "tanh" or "relu". Default: "tanh".

  • dtype (mindspore.dtype) – Data type of the parameters. Default: mstype.float32.
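Because \(W_{ih}\) maps an input_size-dimensional input into the hidden_size-dimensional hidden state and \(W_{hh}\) maps the hidden state onto itself, the parameter shapes and the number of learnable values follow directly from input_size, hidden_size and has_bias. The helpers below are a hypothetical sketch of that bookkeeping; they do not query MindSpore, and the dictionary keys are illustrative labels rather than the attribute names MindSpore uses.

```python
import math
from typing import Dict, Tuple


def rnn_cell_param_shapes(
    input_size: int, hidden_size: int, has_bias: bool = True
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: shapes implied by h_t = act(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh)."""
    shapes = {
        "W_ih": (hidden_size, input_size),   # projects x_t into the hidden space
        "W_hh": (hidden_size, hidden_size),  # projects h_(t-1) into the hidden space
    }
    if has_bias:
        shapes["b_ih"] = (hidden_size,)
        shapes["b_hh"] = (hidden_size,)
    return shapes


def rnn_cell_param_count(input_size: int, hidden_size: int, has_bias: bool = True) -> int:
    # Total number of scalar values across all parameter tensors.
    return sum(
        math.prod(shape)
        for shape in rnn_cell_param_shapes(input_size, hidden_size, has_bias).values()
    )
```

For the RNNCell(10, 16) used in the Examples section this gives 16·10 + 16·16 + 2·16 = 448 values with biases, and 416 without.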

Inputs:
  • x (Tensor) - Tensor of shape \((batch\_size, input\_size)\) .

  • hx (Tensor) - Tensor of data type mindspore.float32 and shape \((batch\_size, hidden\_size)\) .

Outputs:
  • hx' (Tensor) - The updated hidden state, a Tensor of shape \((batch\_size, hidden\_size)\).
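The Inputs/Outputs contract amounts to a simple shape rule: both inputs share the batch dimension, x carries input_size features, hx carries hidden_size features, and the result has the same shape as hx. The function below is a hypothetical sketch of that rule only. It does not check dtypes, and which exception MindSpore raises on a mismatch is not stated on this page, so the ValueError here is an assumption.

```python
from typing import Tuple


def rnn_cell_output_shape(
    x_shape: Tuple[int, int],
    hx_shape: Tuple[int, int],
    input_size: int,
    hidden_size: int,
) -> Tuple[int, int]:
    """Hypothetical shape check mirroring the documented Inputs/Outputs contract."""
    batch_size, x_features = x_shape
    hx_batch, hx_features = hx_shape
    if x_features != input_size:
        raise ValueError(f"x must have {input_size} features, got {x_features}")
    if hx_features != hidden_size:
        raise ValueError(f"hx must have {hidden_size} features, got {hx_features}")
    if hx_batch != batch_size:
        raise ValueError(f"batch mismatch: x has {batch_size} rows, hx has {hx_batch}")
    # The new hidden state keeps the batch dimension and has hidden_size features.
    return (batch_size, hidden_size)
```

In the Examples section each step passes x[i] of shape (3, 10) together with hx of shape (3, 16), so every step returns a (3, 16) hidden state.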

Raises
  • TypeError – If input_size or hidden_size is not an int or not greater than 0.

  • TypeError – If has_bias is not a bool.

  • ValueError – If nonlinearity is not in ['tanh', 'relu'].
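The checks in the Raises list can be mirrored in a few lines. This is a sketch of the documented contract only: MindSpore performs these checks internally with its own validators, and the exact messages, as well as whether a bool is rejected for input_size and hidden_size, are assumptions here. The function name `check_rnn_cell_args` is hypothetical.

```python
VALID_NONLINEARITIES = ("tanh", "relu")


def check_rnn_cell_args(input_size, hidden_size, has_bias=True, nonlinearity="tanh"):
    """Hypothetical validator that follows the documented Raises list."""
    for name, value in (("input_size", input_size), ("hidden_size", hidden_size)):
        # bool is a subclass of int in Python, so it is rejected explicitly (an assumption).
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            raise TypeError(f"{name} must be a positive int, got {value!r}")
    if not isinstance(has_bias, bool):
        raise TypeError(f"has_bias must be a bool, got {type(has_bias).__name__}")
    if nonlinearity not in VALID_NONLINEARITIES:
        raise ValueError(
            f"nonlinearity must be one of {VALID_NONLINEARITIES}, got {nonlinearity!r}"
        )
```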

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> import numpy as np
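>>> net = ms.nn.RNNCell(10, 16)  # input_size=10, hidden_size=16
>>> # x holds 5 time steps for a batch of 3 samples: (seq_len, batch_size, input_size)
>>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
>>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32))  # initial hidden state
>>> output = []
>>> for i in range(5):
...     hx = net(x[i], hx)  # x[i] has shape (3, 10); hx keeps shape (3, 16)
...     output.append(hx)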
>>> print(output[0].shape)
(3, 16)