mindspore.dataset.text.TruncateSequencePair

class mindspore.dataset.text.TruncateSequencePair(max_length)

Truncate a pair of 1-D string inputs so that their total length does not exceed the specified maximum length.

Parameters

max_length (int) – The maximum total length of the two outputs combined. If max_length is greater than or equal to the total length of the original pair, no truncation is performed. Otherwise, the longer of the two inputs is truncated repeatedly until the total length equals max_length.

Raises

TypeError – If max_length is not of type int.

Supported Platforms:

CPU

Examples

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>>
>>> # Use the transform in dataset pipeline mode
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=([[1, 2, 3]], [[4, 5]]), column_names=["col1", "col2"])
>>> # Data before
>>> # |   col1    |   col2    |
>>> # +-----------+-----------+
>>> # | [1, 2, 3] |  [4, 5]   |
>>> # +-----------+-----------+
>>> truncate_sequence_pair_op = text.TruncateSequencePair(max_length=4)
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=truncate_sequence_pair_op,
...                                                 input_columns=["col1", "col2"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col1"], item["col2"])
[1 2] [4 5]
>>> # Data after
>>> # |   col1    |   col2    |
>>> # +-----------+-----------+
>>> # |  [1, 2]   |  [4, 5]   |
>>> # +-----------+-----------+
>>>
>>> # Use the transform in eager mode
>>> data = [["1", "2", "3"], ["4", "5"]]
>>> output = text.TruncateSequencePair(4)(*data)
>>> print(output)
(array(['1', '2'], dtype='<U1'), array(['4', '5'], dtype='<U1'))
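
The truncation rule described under Parameters can be sketched in plain NumPy, which may help when reasoning about expected output lengths without running a pipeline. This is only an illustration, not the MindSpore implementation: the function name truncate_pair_sketch is hypothetical, and the tie-breaking choice (shortening the second sequence when both have equal length) is an assumption that the description above does not specify.

```python
import numpy as np


def truncate_pair_sketch(seq1, seq2, max_length):
    """Hypothetical illustration of the rule: trim the longer input from the end."""
    if max_length < 0:
        raise ValueError("max_length must be non-negative")
    a = np.asarray(seq1)
    b = np.asarray(seq2)
    len1, len2 = a.shape[0], b.shape[0]
    # Drop one trailing element at a time from whichever sequence is longer.
    while len1 + len2 > max_length:
        if len1 > len2:
            len1 -= 1
        else:
            # Assumption: on a tie, the second sequence is shortened.
            len2 -= 1
    return a[:len1], b[:len2]


first, second = truncate_pair_sketch([1, 2, 3], [4, 5], max_length=4)
print(first, second)  # [1 2] [4 5]
```

For the inputs used in the examples above, the sketch yields the same lengths as the documented output: the first sequence loses its last element and the second is left unchanged.
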
Tutorial Examples: