:gitee_url: https://gitee.com/mindspore/docs

.. _program_listing_file_include_samplers.h:

Program Listing for File samplers.h
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_include_samplers.h>` (``include/samplers.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
   #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

   #include <memory>
   #include <vector>

   #include "include/api/types.h"

   namespace mindspore {
   namespace dataset {

   // Forward declare
   class SamplerObj;

   /// \brief Abstract class to represent a sampler in the data pipeline.
   ///
   /// Concrete samplers override Parse() to produce a SamplerObj. The dataset
   /// classes and SelectSampler() listed below are friends, which gives them
   /// access to the protected Parse() and children_.
   ///
   /// NOTE(review): the enable_shared_from_this base is inherited privately
   /// (the default for `class`). With an inaccessible base, std::shared_ptr
   /// does not set up the internal weak reference, so shared_from_this() would
   /// not be usable -- confirm whether public inheritance was intended.
   class Sampler : std::enable_shared_from_this<Sampler> {
     friend class AlbumDataset;
     friend class CelebADataset;
     friend class Cifar10Dataset;
     friend class Cifar100Dataset;
     friend class CityscapesDataset;
     friend class CLUEDataset;
     friend class CocoDataset;
     friend class CSVDataset;
     friend class DIV2KDataset;
     friend class FlickrDataset;
     friend class ImageFolderDataset;
     friend class ManifestDataset;
     friend class MindDataDataset;
     friend class MnistDataset;
     friend class RandomDataDataset;
     friend class SBUDataset;
     friend class TextFileDataset;
     friend class TFRecordDataset;
     friend class USPSDataset;
     friend class VOCDataset;
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor.
     Sampler() {}

     /// \brief Destructor.
     ///
     /// NOTE(review): non-virtual although the class has virtual functions;
     /// deleting a derived sampler through a raw `Sampler *` would be undefined
     /// behaviour. Confirm that samplers are only ever owned via shared_ptr
     /// created with the concrete type.
     ~Sampler() = default;

     /// \brief Append a child sampler to children_.
     /// \param[in] child The sampler to add.
     virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(child); }

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     virtual std::shared_ptr<SamplerObj> Parse() const = 0;

     /// Child samplers registered through AddChild(), in insertion order.
     std::vector<std::shared_ptr<Sampler>> children_;
   };

   /// \brief Sampler configured by shard count, shard id, shuffle flag, sample
   ///     count, seed, offset and an even-distribution flag.
   class DistributedSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] num_shards Number of shards.
     /// \param[in] shard_id Shard id of this sampler.
     /// \param[in] shuffle Shuffle flag (default: true).
     /// \param[in] num_samples Number of samples (default: 0).
     /// \param[in] seed Seed value (default: 1).
     /// \param[in] offset Offset value (default: -1).
     /// \param[in] even_dist Even-distribution flag (default: true).
     DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
                        uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);

     /// \brief Destructor.
     ~DistributedSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t num_shards_;
     int64_t shard_id_;
     bool shuffle_;
     int64_t num_samples_;
     uint32_t seed_;
     int64_t offset_;
     bool even_dist_;
   };

   /// \brief Sampler configured by a per-class value count, a shuffle flag and a
   ///     sample count.
   class PKSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] num_val Value count for this sampler.
     /// \param[in] shuffle Shuffle flag (default: false).
     /// \param[in] num_samples Number of samples (default: 0).
     explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);

     /// \brief Destructor.
     ~PKSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t num_val_;
     bool shuffle_;
     int64_t num_samples_;
   };

   /// \brief Sampler configured by a replacement flag and a sample count.
   class RandomSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] replacement Replacement flag (default: false).
     /// \param[in] num_samples Number of samples (default: 0).
     explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);

     /// \brief Destructor.
     ~RandomSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     bool replacement_;
     int64_t num_samples_;
   };

   /// \brief Sampler configured by a start index and a sample count.
   class SequentialSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] start_index Start index (default: 0).
     /// \param[in] num_samples Number of samples (default: 0).
     explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);

     /// \brief Destructor.
     ~SequentialSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     int64_t start_index_;
     int64_t num_samples_;
   };

   /// \brief Sampler configured by a list of indices and a sample count.
   ///
   /// Not final: SubsetRandomSampler derives from it and reuses the protected
   /// members.
   class SubsetSampler : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] indices List of indices.
     /// \param[in] num_samples Number of samples (default: 0).
     explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);

     /// \brief Destructor.
     ~SubsetSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

     // Protected rather than private so that SubsetRandomSampler can read them.
     std::vector<int64_t> indices_;
     int64_t num_samples_;
   };

   /// \brief SubsetSampler variant that takes the same constructor arguments and
   ///     overrides Parse().
   class SubsetRandomSampler final : public SubsetSampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] indices List of indices.
     /// \param[in] num_samples Number of samples (default: 0).
     explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);

     /// \brief Destructor.
     ~SubsetRandomSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;
   };

   /// \brief Sampler configured by a list of weights, a sample count and a
   ///     replacement flag.
   class WeightedRandomSampler final : public Sampler {
     friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

    public:
     /// \brief Constructor (defined outside this header).
     /// \param[in] weights List of weights.
     /// \param[in] num_samples Number of samples (default: 0).
     /// \param[in] replacement Replacement flag (default: true).
     explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);

     /// \brief Destructor.
     ~WeightedRandomSampler() = default;

    protected:
     /// \brief Convert this sampler into a SamplerObj.
     /// \return Shared pointer to the resulting SamplerObj.
     std::shared_ptr<SamplerObj> Parse() const override;

    private:
     std::vector<double> weights_;
     int64_t num_samples_;
     bool replacement_;
   };

   }  // namespace dataset
   }  // namespace mindspore
   #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_