:gitee_url: https://gitee.com/mindspore/docs


.. _program_listing_file_include_samplers.h:

Program Listing for File samplers.h
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_include_samplers.h>` (``include/samplers.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
   #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

   #include <memory>
   #include <vector>

   #include "include/api/types.h"
   #include "include/api/status.h"

   namespace mindspore {
   namespace dataset {

   // Forward declare
   class SamplerObj;

   // Abstract class to represent a sampler in the data pipeline.
   //
   // The dataset classes and the SelectSampler() helper listed below are friends, which
   // gives them access to the protected Parse()/BuildChildren() interface and to children_.
   //
   // NOTE(review): the base std::enable_shared_from_this<Sampler> is inherited privately
   // (the default for `class`). std::shared_ptr only wires up shared_from_this() when that
   // base is accessible, so confirm whether public inheritance was intended.
   class Sampler : std::enable_shared_from_this<Sampler> {
     friend class AlbumDataset;
     friend class Caltech256Dataset;
     friend class CelebADataset;
     friend class Cifar10Dataset;
     friend class Cifar100Dataset;
     friend class CityscapesDataset;
     friend class CLUEDataset;
     friend class CMUArcticDataset;
     friend class CocoDataset;
     friend class CSVDataset;
     friend class DIV2KDataset;
     friend class EMnistDataset;
     friend class FakeImageDataset;
     friend class FashionMnistDataset;
     friend class FlickrDataset;
     friend class GTZANDataset;
     friend class ImageFolderDataset;
     friend class IMDBDataset;
     friend class KITTIDataset;
     friend class KMnistDataset;
     friend class LFWDataset;
     friend class LibriTTSDataset;
     friend class LJSpeechDataset;
     friend class LSUNDataset;
     friend class ManifestDataset;
     friend class MindDataDataset;
     friend class MnistDataset;
     friend class OmniglotDataset;
     friend class PhotoTourDataset;
     friend class Places365Dataset;
     friend class QMnistDataset;
     friend class RandomDataDataset;
     friend class SBUDataset;
     friend class SemeionDataset;
     friend class SpeechCommandsDataset;
     friend class STL10Dataset;
     friend class TedliumDataset;
     friend class TextFileDataset;
     friend class TFRecordDataset;
     friend class USPSDataset;
     friend class VOCDataset;
     friend class WIDERFaceDataset;
     friend class YesNoDataset;
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     Sampler() = default;

     virtual ~Sampler() = default;

     // Appends `child` to this sampler's list of children.
     virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }

    protected:
     // Pure virtual: each concrete sampler returns its SamplerObj counterpart.
     virtual std::shared_ptr<SamplerObj> Parse() const = 0;

     // Declared here, defined outside this header.
     // Presumably parses children_ and attaches them to *sampler -- TODO confirm against the definition.
     Status BuildChildren(std::shared_ptr<SamplerObj> *const sampler) const;

     // Child samplers added through AddChild().
     std::vector<std::shared_ptr<Sampler>> children_;
   };

   // Sampler configured by a shard count and shard id; the constructor arguments
   // correspond by name and type to the private members below.
   class DistributedSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
                        uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);

     ~DistributedSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t num_shards_;
     int64_t shard_id_;
     bool shuffle_;
     int64_t num_samples_;
     uint32_t seed_;
     int64_t offset_;
     bool even_dist_;
   };

   // Sampler configured by num_val, shuffle and num_samples.
   class PKSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);

     ~PKSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t num_val_;
     bool shuffle_;
     int64_t num_samples_;
   };

   // Sampler configured by a replacement flag and num_samples.
   class RandomSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);

     ~RandomSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     bool replacement_;
     int64_t num_samples_;
   };

   // Sampler configured by a start index and num_samples.
   class SequentialSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);

     ~SequentialSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t start_index_;
     int64_t num_samples_;
   };

   // Sampler configured by an explicit list of indices. Not final: SubsetRandomSampler
   // derives from it, so the data members are protected rather than private.
   class SubsetSampler : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);

     ~SubsetSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

     std::vector<int64_t> indices_;
     int64_t num_samples_;
   };

   // Variant of SubsetSampler with its own Parse() override; it adds no data members.
   class SubsetRandomSampler final : public SubsetSampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);

     ~SubsetRandomSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;
   };

   // Sampler configured by a list of weights, num_samples and a replacement flag.
   class WeightedRandomSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0,
                                    bool replacement = true);

     ~WeightedRandomSampler() = default;

    protected:
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     std::vector<double> weights_;
     int64_t num_samples_;
     bool replacement_;
   };

   }  // namespace dataset
   }  // namespace mindspore
   #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_