mindspore.ops.grid_sample

mindspore.ops.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

Given an input and a flow-field grid, computes the output by sampling input values at the pixel locations specified by grid. Only spatial (4-D) and volumetric (5-D) inputs are supported.

In the spatial (4-D) case, for input with shape \((N, C, H_{in}, W_{in})\) and grid with shape \((N, H_{out}, W_{out}, 2)\), the output will have shape \((N, C, H_{out}, W_{out})\).

For each output location output[n, :, h, w], the size-2 vector grid[n, h, w] specifies the input pixel locations x and y that are used to interpolate the output value output[n, :, h, w]. For 5-D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations used to interpolate output[n, :, d, h, w]. The mode argument selects the interpolation method used to sample the input pixels: “nearest” or “bilinear” (“bicubic” is not supported yet).
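As an illustration of the per-location interpolation described above, here is a minimal pure-Python sketch for a single channel. It assumes the grid entry has already been converted to pixel coordinates `ix` (along W) and `iy` (along H), and that out-of-bound taps read as 0, as with padding_mode=“zeros”. `sample_plane` is a hypothetical helper, not a MindSpore API, and the library's kernels may differ in details such as tie-breaking for “nearest”.

```python
import math

def sample_plane(plane, ix, iy, mode="bilinear"):
    """Sample one H x W channel (a list of rows) at pixel coordinates (ix, iy).

    ix runs along the width (x) and iy along the height (y). Taps that fall
    outside the plane contribute 0, mimicking padding_mode="zeros".
    """
    h, w = len(plane), len(plane[0])

    def at(r, c):
        # Out-of-bound taps read as zero.
        if 0 <= r < h and 0 <= c < w:
            return plane[r][c]
        return 0.0

    if mode == "nearest":
        # Pick the closest pixel; Python's round() breaks ties to even
        # (an assumption about how the library breaks ties).
        return at(round(iy), round(ix))
    if mode != "bilinear":
        raise ValueError(f"unsupported mode: {mode!r}")

    # Top-left neighbour and fractional offsets inside the pixel cell.
    x0, y0 = math.floor(ix), math.floor(iy)
    dx, dy = ix - x0, iy - y0
    # Weight the four neighbours by the area of the opposite sub-rectangle.
    return (at(y0, x0) * (1 - dx) * (1 - dy)
            + at(y0, x0 + 1) * dx * (1 - dy)
            + at(y0 + 1, x0) * (1 - dx) * dy
            + at(y0 + 1, x0 + 1) * dx * dy)
```

For instance, `sample_plane([[0, 1], [2, 3]], 0.6, 0.65)` returns approximately 1.9: each of the four neighbouring pixels is weighted by products of the fractional offsets `dx` and `dy`.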

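grid specifies the sampling pixel locations normalized by the input spatial dimensions, so most of its values should lie in the range \([-1, 1]\). For example, x = -1, y = -1 corresponds to the left-top corner of input and x = 1, y = 1 to the right-bottom corner, with the exact reference point set by align_corners.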

If grid has values outside the range \([-1, 1]\), the corresponding outputs are handled as defined by padding_mode: “zeros” uses \(0\) for out-of-bound grid locations, “border” uses the values of the border pixels, and “reflection” uses the values at locations reflected by the border. Locations far outside the border are reflected repeatedly until they fall within bounds.
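The three padding modes can be viewed as rules that fold an out-of-bound pixel coordinate back into range before sampling. The sketch below is a pure-Python illustration modelled on the coordinate handling of PyTorch's `torch.nn.functional.grid_sample`; it assumes MindSpore follows the same convention, reflecting about pixel centers when align_corners=True and about pixel edges when it is False. `reflect` and `pad_coord` are hypothetical helpers.

```python
import math

def reflect(coord, low, high):
    """Reflect coord into [low, high], bouncing off both borders repeatedly."""
    span = high - low
    if span <= 0:
        return low  # degenerate axis: only one valid position
    dist = abs(coord - low)
    flips = math.floor(dist / span)   # how many times the border is crossed
    extra = math.fmod(dist, span)     # remaining distance after the last bounce
    return low + extra if flips % 2 == 0 else high - extra

def pad_coord(coord, size, padding_mode, align_corners):
    """Map an unnormalized pixel coordinate on an axis of length size (assumed convention)."""
    if padding_mode == "zeros":
        return coord  # left as-is; the sampler reads out-of-bound taps as 0
    if padding_mode == "reflection":
        if align_corners:
            coord = reflect(coord, 0.0, size - 1.0)   # mirror at corner pixel centers
        else:
            coord = reflect(coord, -0.5, size - 0.5)  # mirror at corner pixel edges
    elif padding_mode != "border":
        raise ValueError(f"unsupported padding_mode: {padding_mode!r}")
    # "border" (and the last step of "reflection") clamps to the valid pixel range.
    return min(max(coord, 0.0), size - 1.0)
```

On an axis of length 4 with align_corners=True, the coordinate 8.0 bounces off 3 to -2 and off 0 to 2, so `pad_coord(8.0, 4, "reflection", True)` returns 2.0, whereas “border” would clamp it to 3.0.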

Parameters
  • input (Tensor) – input with shape of \((N, C, H_{in}, W_{in})\) (4-D case) or \((N, C, D_{in}, H_{in}, W_{in})\) (5-D case) and dtype of float32 or float64.

  • grid (Tensor) – flow-field with shape of \((N, H_{out}, W_{out}, 2)\) (4-D case) or \((N, D_{out}, H_{out}, W_{out}, 3)\) (5-D case) and same dtype as input.

  • mode (str) – An optional string specifying the interpolation method. The optional values are “bilinear” and “nearest”; “bicubic” is not supported yet. When mode is “bilinear” and input is 5-D, the interpolation actually performed is trilinear; when input is 4-D, it is bilinear. Default: ‘bilinear’.

  • padding_mode (str) – An optional string specifying the padding method. The optional values are “zeros”, “border” or “reflection”. Default: ‘zeros’.

  • align_corners (bool) – An optional bool. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. Default: False.
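The effect of align_corners on the normalized-to-pixel mapping can be sketched as follows for one axis of length `size`. This is a minimal illustration derived from the description of align_corners above; `unnormalize` is a hypothetical helper, not a MindSpore function.

```python
def unnormalize(coord, size, align_corners):
    """Map a normalized grid coordinate in [-1, 1] to a pixel coordinate."""
    if align_corners:
        # -1 and 1 land on the centers of the first and last pixels.
        return (coord + 1) / 2 * (size - 1)
    # -1 and 1 land on the outer edges of the first and last pixels.
    return ((coord + 1) * size - 1) / 2
```

With `size=4`, the extrema map to 0 and 3 (pixel centers) when align_corners=True, and to -0.5 and 3.5 (outer pixel edges) when it is False; the midpoint 0 maps to 1.5 in both cases.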

Returns

Tensor with the same dtype as input, of shape \((N, C, H_{out}, W_{out})\) in the 4-D case or \((N, C, D_{out}, H_{out}, W_{out})\) in the 5-D case.

Raises
  • TypeError – If input or grid is not a Tensor.

  • TypeError – If the dtypes of input and grid are inconsistent.

  • TypeError – If the dtype of input or grid is neither float32 nor float64.

  • TypeError – If align_corners is not a boolean value.

  • ValueError – If the rank of input or grid is not equal to 4 (4-D case) or 5 (5-D case).

  • ValueError – If the first dimension of input is not equal to that of grid.

  • ValueError – If the last dimension of grid is not equal to 2 (4-D case) or 3 (5-D case).

  • ValueError – If mode is not a string, or is not “bilinear” or “nearest”.

  • ValueError – If padding_mode is not a string, or is not “zeros”, “border” or “reflection”.
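The shape rules under Returns and the ValueError conditions above can be summarised in a small pure-Python check. `grid_sample_output_shape` is a hypothetical helper that works on shape tuples only, so the TypeError conditions (tensor types and dtypes) are not modelled; it also assumes that input and grid must have the same rank.

```python
def grid_sample_output_shape(input_shape, grid_shape, mode="bilinear", padding_mode="zeros"):
    """Check the documented constraints on shapes and options; return the output shape."""
    rank = len(input_shape)
    # Rank check: 4-D or 5-D, and (assumed) the same rank for input and grid.
    if rank not in (4, 5) or len(grid_shape) != rank:
        raise ValueError("input and grid must both be 4-D or both be 5-D")
    # The batch size N must agree.
    if input_shape[0] != grid_shape[0]:
        raise ValueError("input and grid must have the same first dimension")
    # The last grid dimension holds (x, y) or (x, y, z).
    if grid_shape[-1] != rank - 2:
        raise ValueError(f"last dimension of grid must be {rank - 2}")
    if mode not in ("bilinear", "nearest"):
        raise ValueError(f"unsupported mode: {mode!r}")
    if padding_mode not in ("zeros", "border", "reflection"):
        raise ValueError(f"unsupported padding_mode: {padding_mode!r}")
    # (N, C) come from input; the spatial output dims come from grid.
    return tuple(input_shape[:2]) + tuple(grid_shape[1:-1])
```

For the shapes used in the Examples, `grid_sample_output_shape((2, 2, 2, 2), (2, 2, 1, 2))` returns `(2, 2, 2, 1)`, which matches the shape of the printed output.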

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
>>> grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
>>> output = ops.grid_sample(input_x, grid, mode='bilinear', padding_mode='zeros',
...                          align_corners=True)
>>> print(output)
[[[[ 1.9      ]
   [ 2.1999998]]
  [[ 5.9      ]
   [ 6.2      ]]]
 [[[10.5      ]
   [10.8      ]]
  [[14.5      ]
   [14.8      ]]]]
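The printed values can be checked by hand. With align_corners=True, a normalized coordinate x maps to the pixel coordinate \(i_x = \frac{x + 1}{2}(W_{in} - 1)\), which here (\(H_{in} = W_{in} = 2\)) reduces to \((x + 1)/2\). Each 2 x 2 channel of input_x has the form [[b, b+1], [b+2, b+3]] with \(b = 8n + 4c\). This is affine in the pixel coordinates, so bilinear interpolation returns exactly \(b + i_x + 2 i_y\). For output[0, 0, 0, 0], where grid[0, 0, 0] = (0.2, 0.3):

\[
i_x = \frac{0.2 + 1}{2} = 0.6, \qquad i_y = \frac{0.3 + 1}{2} = 0.65, \qquad \text{output}[0, 0, 0, 0] = 0 + 0.6 + 2 \times 0.65 = 1.9.
\]

The same formula gives 2.2 for output[0, 0, 1, 0] (printed as 2.1999998 because of float32 rounding) and \(12 + 0.9 + 2 \times 0.95 = 14.8\) for output[1, 1, 1, 0].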