// Program listing for file ordered_set.h (include/core/utils/ordered_set.h)
#ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_
#define MINDSPORE_CORE_UTILS_ORDERED_SET_H_
#include <algorithm>
#include <vector>
#include <list>
#include <utility>
#include <functional>
#include <memory>
#include "utils/hashing.h"
#include "utils/hash_map.h"
namespace mindspore {
// OrderedSet: a set whose iteration order is the insertion order of its elements.
// A hash map answers membership queries, and a std::list records the elements in the order they were inserted.
// The value mapped to each key is the list iterator of that element, so erasing by value needs no list search.
template <class T, class Hash = std::hash<T>, class KeyEqual = std::equal_to<T>>
class OrderedSet {
 public:
  using element_type = T;
  using hasher = Hash;
  using equal = KeyEqual;
  using sequential_type = std::list<element_type>;
  using vector_type = std::vector<element_type>;
  using iterator = typename sequential_type::iterator;
  using const_iterator = typename sequential_type::const_iterator;
  using reverse_iterator = typename sequential_type::reverse_iterator;
  using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
  using map_type = mindspore::HashMap<element_type, iterator, hasher, equal>;
  using ordered_set_type = OrderedSet<element_type, hasher, equal>;

  OrderedSet() = default;
  ~OrderedSet() = default;

  // The map stores iterators into this object's own list, so copying the source map would leave the copy
  // pointing into the source list. The elements are therefore re-inserted one by one.
  OrderedSet(const OrderedSet &os) {
    reserve(os.size());
    add_all(os.ordered_data_);
  }

  OrderedSet(OrderedSet &&os) = default;

  // Explicitly construct an OrderedSet from a list; duplicates keep their first occurrence.
  explicit OrderedSet(const sequential_type &other) {
    reserve(other.size());
    add_all(other);
  }

  // Explicitly construct an OrderedSet from a vector; duplicates keep their first occurrence.
  explicit OrderedSet(const vector_type &other) {
    reserve(other.size());
    add_all(other);
  }

  // Copy assignment rebuilds the container for the same reason as the copy constructor.
  OrderedSet &operator=(const OrderedSet &other) {
    if (this != &other) {
      clear();
      reserve(other.size());
      add_all(other.ordered_data_);
    }
    return *this;
  }

  OrderedSet &operator=(OrderedSet &&other) = default;

  // Insert `e` in front of position `pos` unless an equal element is already present.
  // Returns the list iterator of the element (new or pre-existing) and whether an insertion took place.
  std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
    // Register the key with a placeholder iterator first; it is replaced below only when the key is new.
    auto result = map_.emplace(e, ordered_data_.end());
    if (result.second) {
      result.first->second = ordered_data_.emplace(pos, e);
    }
    return {result.first->second, result.second};
  }

  // Add an element to the end of the OrderedSet, ignoring whether it was already present.
  void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }

  // Insert an element at the end of the OrderedSet.
  std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }

  void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }

  void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }

  // Remove an element by value. Returns true if it was present and removed, false otherwise.
  bool erase(const element_type &e) {
    auto pos = map_.find(e);
    if (pos == map_.end()) {
      return false;
    }
    // Save the list position before the map entry that holds it is destroyed.
    auto list_pos = pos->second;
    (void)map_.erase(pos);
    (void)ordered_data_.erase(list_pos);
    return true;
  }

  // Remove the element at `pos`; returns the iterator following the removed element.
  iterator erase(iterator pos) {
    // The key must be removed from the map while *pos is still alive.
    (void)map_.erase(*pos);
    return ordered_data_.erase(pos);
  }

  iterator erase(const_iterator pos) {
    (void)map_.erase(*pos);
    return ordered_data_.erase(pos);
  }

  // Return the container size.
  std::size_t size() const { return map_.size(); }

  bool empty() const { return map_.size() == 0; }

  // Remove all elements.
  void clear() {
    if (!map_.empty()) {
      map_.clear();
      ordered_data_.clear();
    }
  }

  // Reserve memory in the map for the given number of entries.
  void reserve(size_t num_entries) { map_.reserve(num_entries); }

  // Two sets are equal only if they hold equal elements in the same order.
  bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }

  // Remove and return the first element. The set must not be empty.
  element_type pop() {
    element_type e = std::move(ordered_data_.front());
    (void)map_.erase(e);
    (void)ordered_data_.erase(ordered_data_.begin());
    return e;
  }

  // The set must not be empty when calling back() or front().
  element_type &back() { return ordered_data_.back(); }
  element_type &front() { return ordered_data_.front(); }
  const element_type &back() const { return ordered_data_.back(); }
  const element_type &front() const { return ordered_data_.front(); }

  // Return true if there are no common elements.
  bool is_disjoint(const OrderedSet &other) const {
    for (const auto &item : other.ordered_data_) {
      if (map_.find(item) != map_.end()) {
        return false;
      }
    }
    return true;
  }

  // Test whether this is a subset of other.
  bool is_subset(const OrderedSet &other) const {
    for (const auto &item : ordered_data_) {
      if (other.map_.find(item) == other.map_.end()) {
        return false;
      }
    }
    return true;
  }

  // Add the elements of other to this set; new elements are appended in other's order.
  void update(const OrderedSet &other) { add_all(other.ordered_data_); }

  // Same as above for a shared set; a null pointer is treated as an empty set.
  void update(const std::shared_ptr<OrderedSet> &other) {
    if (other != nullptr) {
      update(*other);
    }
  }

  void update(const sequential_type &other) { add_all(other); }

  void update(const vector_type &other) { add_all(other); }

  // Return a new set holding this set's elements followed by the elements only found in other.
  ordered_set_type get_union(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    res.update(other);
    return res;
  }

  // Get the union with other set, this operator may cost time because of copy.
  ordered_set_type operator|(const OrderedSet &other) const { return get_union(other); }

  // Return the intersection of two sets, in this set's order.
  ordered_set_type intersection(const OrderedSet &other) const {
    ordered_set_type res;
    for (const auto &item : ordered_data_) {
      if (other.map_.find(item) != other.map_.end()) {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator&(const OrderedSet &other) const { return intersection(other); }

  // Return the symmetric difference of two sets: elements found in exactly one of them.
  ordered_set_type symmetric_difference(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    for (const auto &item : other.ordered_data_) {
      if (map_.find(item) != map_.end()) {
        (void)res.erase(item);
      } else {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator^(const OrderedSet &other) const { return symmetric_difference(other); }

  // Remove the elements which are also in other.
  void difference_update(const OrderedSet &other) {
    if (this == &other) {
      // Erasing from our own list while iterating over it would invalidate the loop iterator.
      // A set minus itself is empty, so just clear.
      clear();
      return;
    }
    erase_all(other.ordered_data_);
  }

  void difference_update(const sequential_type &other) { erase_all(other); }

  void difference_update(const vector_type &other) { erase_all(other); }

  // Return a new set with the elements of this set that are not in other.
  ordered_set_type difference(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    res.difference_update(other);
    return res;
  }

  ordered_set_type operator-(const OrderedSet &other) const { return difference(other); }

  bool contains(const element_type &e) const { return (map_.find(e) != map_.end()); }

  // Return the position of `e`, or end() if it is not in the set.
  const_iterator find(const element_type &e) const {
    auto iter = map_.find(e);
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  iterator find(const element_type &e) {
    auto iter = map_.find(e);
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  // Return the count of an element in the set.
  std::size_t count(const element_type &e) const { return map_.count(e); }

  iterator begin() { return ordered_data_.begin(); }
  iterator end() { return ordered_data_.end(); }
  const_iterator begin() const { return ordered_data_.cbegin(); }
  const_iterator end() const { return ordered_data_.cend(); }
  const_iterator cbegin() const { return ordered_data_.cbegin(); }
  const_iterator cend() const { return ordered_data_.cend(); }

 private:
  // Append every element of `range` that is not yet in the set.
  template <class Range>
  void add_all(const Range &range) {
    for (const auto &item : range) {
      add(item);
    }
  }

  // Remove every element of `range` from the set. `range` must not be this set's own list.
  template <class Range>
  void erase_all(const Range &range) {
    for (const auto &item : range) {
      (void)erase(item);
    }
  }

  map_type map_;
  sequential_type ordered_data_;
};
// OrderedSet specialization for std::shared_ptr elements.
// The map is keyed by the raw pointer obtained from shared_ptr::get() and hashed with PointerHash<T>,
// so two elements are the same member exactly when they point to the same object.
// The list keeps the shared_ptr values themselves in insertion order.
template <class T>
class OrderedSet<std::shared_ptr<T>> {
 public:
  using element_type = std::shared_ptr<T>;
  using key_type = const T *;
  using hash_t = PointerHash<T>;
  using sequential_type = std::list<element_type>;
  using vector_type = std::vector<element_type>;
  using iterator = typename sequential_type::iterator;
  using const_iterator = typename sequential_type::const_iterator;
  using reverse_iterator = typename sequential_type::reverse_iterator;
  using const_reverse_iterator = typename sequential_type::const_reverse_iterator;
  using map_type = mindspore::HashMap<key_type, iterator, hash_t>;
  using ordered_set_type = OrderedSet<std::shared_ptr<T>>;

  OrderedSet() = default;
  ~OrderedSet() = default;

  // The map stores iterators into this object's own list, so a copy must rebuild both containers
  // by re-inserting the elements instead of copying the source map.
  OrderedSet(const OrderedSet &os) {
    reserve(os.size());
    add_all(os.ordered_data_);
  }

  OrderedSet(OrderedSet &&os) = default;

  // Explicitly construct an OrderedSet from a list; duplicates keep their first occurrence.
  explicit OrderedSet(const sequential_type &other) {
    reserve(other.size());
    add_all(other);
  }

  // Explicitly construct an OrderedSet from a vector; duplicates keep their first occurrence.
  explicit OrderedSet(const vector_type &other) {
    reserve(other.size());
    add_all(other);
  }

  // Copy assignment rebuilds the container for the same reason as the copy constructor.
  OrderedSet &operator=(const OrderedSet &other) {
    if (this != &other) {
      clear();
      reserve(other.size());
      add_all(other.ordered_data_);
    }
    return *this;
  }

  OrderedSet &operator=(OrderedSet &&other) = default;

  // Insert `e` in front of position `pos` unless the same pointer is already present.
  // Returns the list iterator of the element (new or pre-existing) and whether an insertion took place.
  std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
    // Register the key with a placeholder iterator first; it is replaced below only when the key is new.
    auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
    if (inserted) {
      map_iter->second = ordered_data_.emplace(pos, e);
    }
    return {map_iter->second, inserted};
  }

  // Rvalue overload: `e` is moved into the list only when it is actually inserted.
  std::pair<iterator, bool> insert(const iterator &pos, element_type &&e) {
    auto [map_iter, inserted] = map_.emplace(e.get(), iterator{});
    if (inserted) {
      map_iter->second = ordered_data_.emplace(pos, std::move(e));
    }
    return {map_iter->second, inserted};
  }

  // Add an element to the end of the OrderedSet, ignoring whether it was already present.
  void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }
  void add(element_type &&e) { (void)insert(ordered_data_.end(), std::move(e)); }

  // Insert an element at the end of the OrderedSet.
  std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }
  std::pair<iterator, bool> insert(element_type &&e) { return insert(ordered_data_.end(), std::move(e)); }

  void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }

  void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }

  // Remove an element by value. Returns true if it was present and removed, false otherwise.
  bool erase(const element_type &e) {
    auto pos = map_.find(e.get());
    if (pos == map_.end()) {
      return false;
    }
    // Save the list position before the map entry that holds it is destroyed.
    auto iter = pos->second;
    (void)map_.erase(pos);
    (void)ordered_data_.erase(iter);
    return true;
  }

  // Remove the element at `pos`; returns the iterator following the removed element.
  iterator erase(const iterator &pos) {
    // The key must be read from *pos while the list element is still alive.
    (void)map_.erase(pos->get());
    return ordered_data_.erase(pos);
  }

  iterator erase(const_iterator pos) {
    (void)map_.erase(pos->get());
    return ordered_data_.erase(pos);
  }

  // Return the container size.
  std::size_t size() const { return ordered_data_.size(); }

  bool empty() const { return ordered_data_.empty(); }

  // Remove all elements.
  void clear() {
    if (!map_.empty()) {
      map_.clear();
      ordered_data_.clear();
    }
  }

  // Reserve memory in the map for the given number of entries.
  void reserve(size_t num_entries) { map_.reserve(num_entries); }

  // Two sets are equal only if they hold the same pointers in the same order.
  bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; }

  // Remove and return the first element. The set must not be empty.
  element_type pop() {
    element_type e = std::move(ordered_data_.front());
    (void)map_.erase(e.get());
    (void)ordered_data_.erase(ordered_data_.begin());
    return e;
  }

  // The set must not be empty when calling back() or front().
  element_type &back() { return ordered_data_.back(); }
  element_type &front() { return ordered_data_.front(); }
  const element_type &back() const { return ordered_data_.back(); }
  const element_type &front() const { return ordered_data_.front(); }

  // Return true if there are no common elements.
  bool is_disjoint(const OrderedSet &other) const {
    return std::all_of(begin(), end(), [&other](const auto &e) { return !other.contains(e); });
  }

  // Test whether this is a subset of other.
  bool is_subset(const OrderedSet &other) const {
    return std::all_of(begin(), end(), [&other](const auto &e) { return other.contains(e); });
  }

  // Add the elements of other to this set; new elements are appended in other's order.
  void update(const OrderedSet &other) { add_all(other.ordered_data_); }

  // Same as above for a shared set; a null pointer is treated as an empty set.
  void update(const std::shared_ptr<OrderedSet> &other) {
    if (other != nullptr) {
      update(*other);
    }
  }

  void update(const sequential_type &other) { add_all(other); }

  void update(const vector_type &other) { add_all(other); }

  // Return a new set holding this set's elements followed by the elements only found in other.
  ordered_set_type get_union(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    res.update(other);
    return res;
  }

  // Get the union with other set, this operator may cost time because of copy.
  ordered_set_type operator|(const OrderedSet &other) const { return get_union(other); }

  // Return the intersection of two sets, in this set's order.
  ordered_set_type intersection(const OrderedSet &other) const {
    ordered_set_type res;
    for (const auto &item : ordered_data_) {
      if (other.contains(item)) {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator&(const OrderedSet &other) const { return intersection(other); }

  // Return the symmetric difference of two sets: elements found in exactly one of them.
  ordered_set_type symmetric_difference(const OrderedSet &other) const {
    ordered_set_type res(ordered_data_);
    for (const auto &item : other) {
      if (contains(item)) {
        (void)res.erase(item);
      } else {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator^(const OrderedSet &other) const { return symmetric_difference(other); }

  // Remove the elements which are also in other.
  void difference_update(const OrderedSet &other) {
    if (this == &other) {
      // Erasing from our own list while iterating over it would invalidate the loop iterator.
      // A set minus itself is empty, so just clear.
      clear();
      return;
    }
    erase_all(other.ordered_data_);
  }

  void difference_update(const sequential_type &other) { erase_all(other); }

  void difference_update(const vector_type &other) { erase_all(other); }

  // Return a new set with the elements of this set that are not in other, in this set's order.
  ordered_set_type difference(const OrderedSet &other) const {
    ordered_set_type res;
    for (const auto &item : ordered_data_) {
      if (!other.contains(item)) {
        res.add(item);
      }
    }
    return res;
  }

  ordered_set_type operator-(const OrderedSet &other) const { return difference(other); }

  bool contains(const element_type &e) const { return (map_.find(e.get()) != map_.end()); }

  // Return the position of `e`, or end() if it is not in the set.
  const_iterator find(const element_type &e) const {
    auto iter = map_.find(e.get());
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  iterator find(const element_type &e) {
    auto iter = map_.find(e.get());
    if (iter == map_.end()) {
      return ordered_data_.end();
    }
    return iter->second;
  }

  // Return the count of an element in the set.
  std::size_t count(const element_type &e) const { return map_.count(e.get()); }

  iterator begin() { return ordered_data_.begin(); }
  iterator end() { return ordered_data_.end(); }
  const_iterator begin() const { return ordered_data_.cbegin(); }
  const_iterator end() const { return ordered_data_.cend(); }
  const_iterator cbegin() const { return ordered_data_.cbegin(); }
  const_iterator cend() const { return ordered_data_.cend(); }

 private:
  // Append every element of `range` that is not yet in the set.
  template <class Range>
  void add_all(const Range &range) {
    for (const auto &item : range) {
      add(item);
    }
  }

  // Remove every element of `range` from the set. `range` must not be this set's own list.
  template <class Range>
  void erase_all(const Range &range) {
    for (const auto &item : range) {
      (void)erase(item);
    }
  }

  map_type map_;
  sequential_type ordered_data_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_ORDERED_SET_H_