Program Listing for File abstract_value.h
Return to documentation for file (include/converter/include/core/abstract/abstract_value.h)
#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_
#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_
#include <cstdint>
#include <functional>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"
#include "utils/hashing.h"
#include "utils/any.h"
#include "utils/hash_map.h"
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "abstract/dshape.h"
#include "abstract/utils.h"
#include "utils/shape_utils.h"
namespace mindspore {
namespace abstract {
class AbstractBase;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
/// \brief Root of the abstract-value hierarchy declared in this header.
///
/// Every AbstractBase tracks a value (GetValueTrack), a type (GetTypeTrack) and a shape
/// (GetShapeTrack). A value track equal to kAnyValue means the concrete value is unknown
/// (see IsBroaden()).
class MS_CORE_API AbstractBase : public Base {
public:
/// Callback type stored in trace_node_provider_; it receives a pointer to an AnfNodePtr
/// (NOTE(review): presumably an out-parameter filled by the provider — confirm at install site).
using TraceNodeProvider = std::function<void(AnfNodePtr *node)>;
/// \param value Initial value track; defaults to nullptr.
/// \param type Initial type track; defaults to kAnyType.
/// \param shape Initial shape track; defaults to kNoShape.
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
const BaseShapePtr &shape = kNoShape)
: value_(value), type_(type), shape_(shape) {}
~AbstractBase() override = default;
MS_DECLARE_PARENT(AbstractBase, Base)
/// Default hash uses only the dynamic type id; subclasses override it to mix in their contents.
std::size_t hash() const override { return tid(); }
std::string ToString() const override;
virtual std::string ToString(bool verbose) const;
virtual bool operator==(const AbstractBase &other) const;
/// The three setters below reject null arguments via MS_EXCEPTION_IF_NULL.
void set_value(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
value_ = value;
}
void set_type(const TypePtr &type) {
MS_EXCEPTION_IF_NULL(type);
type_ = type;
}
virtual void set_shape(const BaseShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
shape_ = shape;
}
/// Free-form description of the initial value, kept for error reporting.
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
/// Raw accessors for the tracked value, type and shape.
const ValuePtr &GetValueTrack() const { return value_; }
const TypePtr &GetTypeTrack() const { return type_; }
const BaseShapePtr &GetShapeTrack() const { return shape_; }
/// Non-virtual value builder (defined out of line); presumably delegates to RealBuildValue().
ValuePtr BuildValue() const;
virtual TypePtr BuildType() const = 0;
virtual BaseShapePtr BuildShape() const { return kNoShape; }
virtual AbstractBasePtr Clone() const = 0;
static void set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
trace_node_provider_ = trace_node_provider;
}
/// Static callback shared by all AbstractBase instances.
static TraceNodeProvider trace_node_provider_;
virtual AbstractBasePtr Broaden() const;
/// Default Join ignores `other` and returns this same instance.
virtual AbstractBasePtr Join(const AbstractBasePtr &other) { return shared_from_base<AbstractBase>(); }
/// True when the value track is kAnyValue, i.e. the concrete value is unknown.
bool IsBroaden() const { return value_ == kAnyValue; }
/// Streams a->ToString(). A null pointer is printed as "nullptr" instead of being dereferenced.
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
// Streaming is typically used in log statements; dereferencing a null pointer there is UB.
if (a == nullptr) {
os << "nullptr";
return os;
}
os << a->ToString();
return os;
}
virtual AbstractBasePtr PartialBroaden() const;
bool value_mutable() const { return value_mutable_; }
void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }
/// Function-pointer hook evaluated on a condition abstract. The meaning of the returned
/// bool pair is defined by whoever installs it (NOTE(review): document at the install site).
using InterpretBoolChecker = std::pair<bool, bool> (*)(const AbstractBasePtr &cond);
static inline InterpretBoolChecker interpret_bool_checker_ = nullptr;
static void set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
static inline InterpretBoolChecker interpret_bool_checker() { return interpret_bool_checker_; }
std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
protected:
/// Default value construction: the value is unknown (kAnyValue).
virtual ValuePtr RealBuildValue() const { return kAnyValue; }
std::string name_;
private:
ValuePtr value_;
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
bool value_mutable_{false};
};
/// \brief Abstract value for a scalar, tracking a value and its type.
class MS_CORE_API AbstractScalar final : public AbstractBase {
public:
/// Unknown scalar: value kAnyValue, type kAnyType.
AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {}
AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
/// Takes the type from value->type().
/// NOTE(review): value is dereferenced without a null check.
explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}
/// Convenience constructors: wrap a C++ literal with MakeValue and pair it with the matching type.
explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}
explicit AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}
explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {}
explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {}
explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {}
explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {}
/// Scalar with a known type but an unknown value (kAnyValue).
explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {}
~AbstractScalar() override = default;
MS_DECLARE_PARENT(AbstractScalar, AbstractBase)
/// Combines type id, value hash and type hash; both tracks must be non-null.
std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); }
TypePtr BuildType() const override { return GetTypeTrack(); }
/// Shares the value track and clones the type track.
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
}
AbstractBasePtr Broaden() const override;
AbstractBasePtr Join(const AbstractBasePtr &other) override;
};
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;
/// \brief Abstract value that represents a type object.
///
/// The value track holds the represented TypePtr; the type track is kTypeType.
class MS_CORE_API AbstractType final : public AbstractBase {
public:
/// Raises via MS_LOG(EXCEPTION) when type is nullptr.
explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) {
if (type == nullptr) {
MS_LOG(EXCEPTION) << "type is nullptr";
}
}
~AbstractType() override = default;
MS_DECLARE_PARENT(AbstractType, AbstractBase)
std::string ToString() const override;
bool operator==(const AbstractBase &other) const override;
/// Always builds a fresh TypeType.
TypePtr BuildType() const override { return std::make_shared<TypeType>(); }
AbstractBasePtr Clone() const override;
/// Not broadened: Broaden() just returns a clone.
AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractTypePtr = std::shared_ptr<AbstractType>;
/// \brief Abstract value that records an error string together with its originating node.
class MS_CORE_API AbstractError final : public AbstractBase {
public:
/// \param err Error description, stored as the value track (the type track stays kAnyType).
/// \param node Originating node, kept for debugging only.
/// Raises via MS_LOG(EXCEPTION) when either argument is nullptr.
AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
if (err == nullptr || node == nullptr) {
MS_LOG(EXCEPTION) << "err or node is nullptr";
}
}
~AbstractError() override = default;
MS_DECLARE_PARENT(AbstractError, AbstractBase)
/// Always builds a fresh Problem type.
TypePtr BuildType() const override { return std::make_shared<Problem>(); }
/// Not broadened: Broaden() just returns a clone.
AbstractBasePtr Broaden() const override { return Clone(); }
/// Rebuilds from the stored error string and the same node.
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_);
}
std::string ToString() const override;
private:
// Origin node been specialized to AbstractError, for debug purpose only.
const AnfNodePtr node_;
};
/// \brief Abstract value that wraps a "script" value (NOTE(review): exact meaning defined by its users).
///
/// Same value/type layout as AbstractScalar, but Broaden() never discards the value.
class MS_CORE_API AbstractScript final : public AbstractBase {
public:
/// Unknown script: value kAnyValue, type kAnyType.
AbstractScript() : AbstractBase(kAnyValue, kAnyType) {}
AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
/// Single-value form always uses kString as the type track.
explicit AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}
~AbstractScript() override = default;
MS_DECLARE_PARENT(AbstractScript, AbstractBase)
/// Combines type id, value hash and type hash; both tracks must be non-null.
std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); }
TypePtr BuildType() const override { return GetTypeTrack(); }
/// Shares the value track and clones the type track.
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone());
}
AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractScriptPtr = std::shared_ptr<AbstractScript>;
class Evaluator;
using EvaluatorPtr = std::shared_ptr<Evaluator>;
class AnalysisEngine;
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;
class AbstractFunction;
using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
class AbstractFuncAtom;
using AbstractFuncAtomPtr = std::shared_ptr<AbstractFuncAtom>;
using AbstractFuncAtomPtrList = std::vector<AbstractFuncAtomPtr>;
/// \brief Abstract base for function values.
///
/// Subclasses implement GetUnique(), Copy(), Join(AbstractFunctionPtr), Visit() and
/// operator==(const AbstractFunction &).
class MS_CORE_API AbstractFunction : public AbstractBase {
public:
AbstractFunction() = default;
~AbstractFunction() override = default;
MS_DECLARE_PARENT(AbstractFunction, AbstractBase)
virtual AbstractFunctionPtr GetUnique() = 0;
/// Always builds a fresh Function type.
TypePtr BuildType() const override { return std::make_shared<Function>(); }
/// Clone delegates to the subclass's Copy().
AbstractBasePtr Clone() const override { return Copy(); }
/// Not broadened: returns a shared pointer to this same object. The const_cast is presumably
/// needed because shared_from_base is not const-qualified.
AbstractBasePtr Broaden() const override {
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
}
virtual AbstractFunctionPtr Copy() const = 0;
/// Final override of the generic Join (defined out of line).
AbstractBasePtr Join(const AbstractBasePtr &other) final;
virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0;
/// Invokes the callback on the function atom(s) of this abstract (presumably each atom once).
virtual void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const = 0;
bool operator==(const AbstractBase &other) const final;
virtual bool operator==(const AbstractFunction &other) const = 0;
/// Builds a function abstract from a list of atoms (defined out of line).
static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list);
/// Tracking id defaults to 0.
virtual std::uintptr_t tracking_id() const { return 0; }
virtual AbstractFunctionPtr CopyWithoutTrackingId() const { return Copy(); }
virtual AnalysisContextPtr context() const { return nullptr; }
/// Derives a tracking id from the node's address.
static std::uintptr_t ToTrackingId(const AnfNodePtr &node) { return reinterpret_cast<std::uintptr_t>(node.get()); }
};
using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>;
/// \brief Abstract for a keyword argument: a key name (arg_name_) and the abstract of its value (arg_value_).
class MS_CORE_API AbstractKeywordArg final : public AbstractBase {
public:
/// Value/type/shape tracks keep the AbstractBase defaults; neither argument is null-checked here.
AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {}
~AbstractKeywordArg() override = default;
MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase)
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
std::size_t hash() const override;
bool operator==(const AbstractKeywordArg &other) const;
bool operator==(const AbstractBase &other) const override;
/// Returns a copy of the key name.
std::string get_key() const { return arg_name_; }
/// Returns the abstract of the argument value.
AbstractBasePtr get_arg() const { return arg_value_; }
std::string ToString() const override;
protected:
ValuePtr RealBuildValue() const override;
private:
std::string arg_name_;
AbstractBasePtr arg_value_;
};
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
/// \brief Abstract whose concrete kind is not fixed yet, described by an element abstract and a shape.
///
/// The value track is kAnyValue. The element abstract must not itself be an
/// AbstractUndetermined, and the shape must not be NoShape.
class MS_CORE_API AbstractUndetermined : public AbstractBase {
public:
/// Instance without an element; the shape track keeps the AbstractBase default.
AbstractUndetermined() : AbstractBase(kAnyValue) {}
/// Raises when element is null or is itself an AbstractUndetermined, or when shape is null or NoShape.
explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractBase(kAnyValue), element_(element) {
if (element == nullptr) {
MS_LOG(EXCEPTION) << "element is nullptr";
}
if (element->isa<AbstractUndetermined>()) {
MS_LOG(EXCEPTION) << "element type error";
}
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<NoShape>()) {
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
}
AbstractBase::set_shape(shape);
}
/// Element becomes AbstractScalar(kAnyValue, element_type); raises when element_type is null.
AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
if (element_type == nullptr) {
MS_LOG(EXCEPTION) << "element_type is nullptr";
}
AbstractBase::set_shape(std::make_shared<Shape>(shape));
}
/// Element becomes AbstractScalar(kAnyValue, element_type); raises when element_type is null,
/// or when shape is null or NoShape.
explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
if (element_type == nullptr) {
MS_LOG(EXCEPTION) << "element_type is nullptr";
}
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<NoShape>()) {
MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
}
AbstractBase::set_shape(shape);
}
~AbstractUndetermined() override = default;
MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
/// Always builds a fresh UndeterminedType.
TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
/// \brief Returns an equivalent copy: the element and shape are cloned and the value track is kept.
///
/// An instance without an element (default-constructed) clones to a fresh default instance.
AbstractBasePtr Clone() const override {
if (element_ == nullptr) {
return std::make_shared<AbstractUndetermined>();
}
const auto &shape = GetShapeTrack();
MS_EXCEPTION_IF_NULL(shape);
auto clone = std::make_shared<AbstractUndetermined>(element_->Clone(), shape->Clone());
clone->set_value(GetValueTrack());
return clone;
}
AbstractBasePtr element() const { return element_; }
ShapePtr shape() const;
void set_shape(const BaseShapePtr &shape) override;
protected:
AbstractBasePtr element_;
};
/// \brief Abstract value for a tensor: an element abstract carrying the element type, plus a shape.
///
/// Optionally also carries a value range (min/max) and a shape value; these are stored as-is.
class MS_CORE_API AbstractTensor : public AbstractUndetermined {
public:
explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {}
/// Uses tensor->Dtype() and tensor->shape(). NOTE(review): tensor is dereferenced without a null check.
explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
explicit AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element_type, shape) {}
~AbstractTensor() override = default;
MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
/// Stores the bounds as given; no validation is done here.
void set_value_range(const ValuePtr &min_value, const ValuePtr &max_value) {
min_value_ = min_value;
max_value_ = max_value;
}
const ValuePtr &get_min_value() const { return min_value_; }
const ValuePtr &get_max_value() const { return max_value_; }
void set_shape_value(const ValuePtr &shape_value) { shape_value_ = shape_value; }
const ValuePtr &get_shape_value() const { return shape_value_; }
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other) override;
virtual bool operator==(const AbstractTensor &other) const;
bool operator==(const AbstractBase &other) const override;
std::string ToString() const override;
/// Hash mixes the type id, the element hash and (if present) the shape hash.
std::size_t hash() const override {
// We have to exclude value pointer from hash, because CSE (Common Subexpression Elimination)
// will use this hash to find duplicate ValueNodes that Tensor values are equal.
auto hash_sum = hash_combine(tid(), element_->hash());
const auto &shape = GetShapeTrack();
if (shape != nullptr) {
hash_sum = hash_combine(hash_sum, shape->hash());
}
return hash_sum;
}
AbstractBasePtr PartialBroaden() const override;
protected:
bool equal_to(const AbstractTensor &other) const;
ValuePtr min_value_ = nullptr;
ValuePtr max_value_ = nullptr;
ValuePtr shape_value_ = nullptr;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
/// \brief Common base for tuple/list abstracts: an ordered list of element abstracts plus an
/// optional shared list of weak references to the nodes that produce this sequence.
class MS_CORE_API AbstractSequence : public AbstractBase {
public:
explicit AbstractSequence(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes);
explicit AbstractSequence(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes);
~AbstractSequence() override = default;
MS_DECLARE_PARENT(AbstractSequence, AbstractBase)
/// Per-element helpers (defined out of line) used by the concrete sequence subclasses.
TypePtrList ElementsType() const;
BaseShapePtrList ElementsShape() const;
AbstractBasePtrList ElementsClone() const;
AbstractBasePtrList ElementsBroaden() const;
AbstractBasePtrList ElementsPartialBroaden() const;
/// T is the concrete value container to build (ValueTuple / ValueList in the subclasses below).
template <typename T>
ValuePtr ElementsBuildValue() const;
/// T is the concrete abstract sequence type to produce (AbstractTuple / AbstractList below).
template <typename T>
AbstractBasePtr ElementsJoin(const AbstractBasePtr &other);
AnfNodeWeakPtrList SequenceNodesJoin(const AbstractBasePtr &other);
std::size_t size() const { return elements_.size(); }
bool empty() const { return elements_.empty(); }
const AbstractBasePtrList &elements() const { return elements_; }
bool PurifyElements();
/// The node list is held through a shared_ptr, so copies made with sequence_nodes() share it.
const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes() const { return sequence_nodes_; }
void set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes) {
sequence_nodes_ = sequence_nodes;
}
void InsertSequenceNode(const AnfNodePtr &sequence_node);
void InsertSequenceNodes(const AnfNodeWeakPtrList &sequence_nodes);
void UpdateSequenceNode(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node);
std::size_t hash() const override;
std::string ToStringInternal() const;
std::string ToString() const override;
std::string ToString(bool verbose) const override;
/// Element access by index (defined out of line).
const AbstractBasePtr operator[](const std::size_t &dim) const;
virtual bool operator==(const AbstractSequence &other) const;
protected:
AbstractBasePtrList elements_;
// There are usually only a few nodes, so a plain vector is used instead of a set.
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes_;
};
using AbstractSequencePtr = std::shared_ptr<AbstractSequence>;
/// \brief Abstract value for a tuple; builds Tuple types, TupleShape shapes and ValueTuple values.
class MS_CORE_API AbstractTuple : public AbstractSequence {
public:
explicit AbstractTuple(AbstractBasePtrList &&elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSequence(std::move(elements), tuple_nodes) {}
explicit AbstractTuple(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSequence(elements, tuple_nodes) {}
~AbstractTuple() override = default;
MS_DECLARE_PARENT(AbstractTuple, AbstractSequence)
void set_shape(const BaseShapePtr &shape) override;
TypePtr BuildType() const override { return std::make_shared<Tuple>(ElementsType()); }
BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShape()); }
/// Clone/Broaden/PartialBroaden transform the elements but share the same sequence_nodes list.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone(), sequence_nodes()); }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractTuple>(ElementsBroaden(), sequence_nodes());
}
AbstractBasePtr PartialBroaden() const override {
return std::make_shared<AbstractTuple>(ElementsPartialBroaden(), sequence_nodes());
}
/// Joins element-wise, then records the joined sequence nodes on the result.
AbstractBasePtr Join(const AbstractBasePtr &other) override {
auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractTuple>(other));
MS_EXCEPTION_IF_NULL(res);
res->InsertSequenceNodes(SequenceNodesJoin(other));
return res;
}
bool ContainsAllBroadenTensors() const;
bool operator==(const AbstractTuple &other) const;
bool operator==(const AbstractBase &other) const override;
protected:
ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueTuple>(); }
};
using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;
/// \brief Abstract value for a list; builds List types, ListShape shapes and ValueList values.
class MS_CORE_API AbstractList final : public AbstractSequence {
public:
explicit AbstractList(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr)
: AbstractSequence(std::move(elements), list_nodes) {}
explicit AbstractList(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr)
: AbstractSequence(elements, list_nodes) {}
~AbstractList() override = default;
MS_DECLARE_PARENT(AbstractList, AbstractSequence)
TypePtr BuildType() const override { return std::make_shared<List>(ElementsType()); }
BaseShapePtr BuildShape() const override { return std::make_shared<ListShape>(ElementsShape()); }
/// Clone/Broaden/PartialBroaden transform the elements but share the same sequence_nodes list.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone(), sequence_nodes()); }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractList>(ElementsBroaden(), sequence_nodes());
}
AbstractBasePtr PartialBroaden() const override {
return std::make_shared<AbstractList>(ElementsPartialBroaden(), sequence_nodes());
}
/// Joins element-wise, then records the joined sequence nodes on the result.
AbstractBasePtr Join(const AbstractBasePtr &other) override {
auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractList>(other));
MS_EXCEPTION_IF_NULL(res);
res->InsertSequenceNodes(SequenceNodesJoin(other));
return res;
}
bool operator==(const AbstractList &other) const;
bool operator==(const AbstractBase &other) const override;
protected:
ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueList>(); }
};
using AbstractListPtr = std::shared_ptr<AbstractList>;
/// \brief Abstract value for a dictionary, stored as an ordered vector of AbstractAttribute entries.
/// NOTE(review): AbstractAttribute is declared elsewhere (presumably a key/abstract pair) — confirm.
class MS_CORE_API AbstractDictionary final : public AbstractBase {
public:
explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {}
~AbstractDictionary() override = default;
MS_DECLARE_PARENT(AbstractDictionary, AbstractBase)
TypePtr BuildType() const override;
bool operator==(const AbstractDictionary &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
std::size_t hash() const override;
/// Number of entries.
std::size_t size() const { return key_values_.size(); }
const std::vector<AbstractAttribute> &elements() const { return key_values_; }
protected:
ValuePtr RealBuildValue() const override;
std::vector<AbstractAttribute> key_values_;
};
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;
/// \brief Abstract value for a slice, holding the abstracts of its start, stop and step.
class MS_CORE_API AbstractSlice final : public AbstractBase {
public:
/// The three abstracts are stored as given; they are not null-checked here.
AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step)
: start_(start), stop_(stop), step_(step) {}
~AbstractSlice() override = default;
MS_DECLARE_PARENT(AbstractSlice, AbstractBase)
TypePtr BuildType() const override;
bool operator==(const AbstractSlice &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
std::size_t hash() const override;
AbstractBasePtr start() const { return start_; }
AbstractBasePtr stop() const { return stop_; }
AbstractBasePtr step() const { return step_; }
protected:
ValuePtr RealBuildValue() const override;
private:
AbstractBasePtr start_;
AbstractBasePtr stop_;
AbstractBasePtr step_;
};
using AbstractSlicePtr = std::shared_ptr<AbstractSlice>;
/// \brief Abstract wrapper around an inner abstract value (element_).
///
/// NOTE(review): "J" presumably refers to the J (gradient) transform — confirm with its users.
class MS_CORE_API AbstractJTagged final : public AbstractBase {
public:
/// Stores element as given. It is not null-checked here, but Clone/Broaden/hash dereference it.
explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {}
~AbstractJTagged() override = default;
MS_DECLARE_PARENT(AbstractJTagged, AbstractBase)
TypePtr BuildType() const override;
/// Wraps a clone of the inner element.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }
/// Wraps the broadened inner element.
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractJTagged &other) const;
bool operator==(const AbstractBase &other) const override;
std::string ToString() const override;
/// Returns the wrapped element. Const-qualified because it only copies the shared_ptr, so it
/// is also usable on const instances.
AbstractBasePtr element() const { return element_; }
/// Mixes the type id with the element's hash.
std::size_t hash() const override { return hash_combine(tid(), element_->hash()); }
private:
AbstractBasePtr element_;
};
using AbstractJTaggedPtr = std::shared_ptr<AbstractJTagged>;
/// \brief Abstract value for None: default (nullptr) value track, TypeNone type track.
class MS_CORE_API AbstractNone final : public AbstractBase {
public:
AbstractNone() : AbstractBase() { set_type(std::make_shared<TypeNone>()); }
~AbstractNone() override = default;
MS_DECLARE_PARENT(AbstractNone, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<TypeNone>(); }
bool operator==(const AbstractNone &) const;
bool operator==(const AbstractBase &other) const override;
/// Stateless: a clone is simply a new instance.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractNone>(); }
std::string ToString() const override;
protected:
ValuePtr RealBuildValue() const override;
};
using AbstractNonePtr = std::shared_ptr<AbstractNone>;
/// \brief Abstract value for null: value track kNull, type track TypeNull.
class MS_CORE_API AbstractNull final : public AbstractBase {
public:
AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
~AbstractNull() override = default;
MS_DECLARE_PARENT(AbstractNull, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }
bool operator==(const AbstractNull &) const;
bool operator==(const AbstractBase &other) const override;
/// Stateless: a clone is simply a new instance.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractNull>(); }
std::string ToString() const override;
};
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
/// \brief Abstract value marking a time-out.
///
/// Uses the same value (kNull) and type (TypeNull) as AbstractNull; the two are told apart by
/// their dynamic type.
class MS_CORE_API AbstractTimeOut final : public AbstractBase {
public:
AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
~AbstractTimeOut() override = default;
MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }
bool operator==(const AbstractTimeOut &) const;
bool operator==(const AbstractBase &other) const override;
/// Stateless: a clone is simply a new instance.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTimeOut>(); }
std::string ToString() const override;
};
using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>;
/// \brief Abstract value for an ellipsis: value track kEllipsis, type track TypeEllipsis.
class MS_CORE_API AbstractEllipsis final : public AbstractBase {
public:
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }
~AbstractEllipsis() override = default;
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<TypeEllipsis>(); }
bool operator==(const AbstractEllipsis &) const;
bool operator==(const AbstractBase &other) const override;
/// Stateless: a clone is simply a new instance.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
std::string ToString() const override;
};
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
/// \brief Tensor abstract that additionally carries a reference key value (ref_key_value_).
class MS_CORE_API AbstractRefTensor final : public AbstractTensor {
public:
AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value);
~AbstractRefTensor() override = default;
MS_DECLARE_PARENT(AbstractRefTensor, AbstractTensor)
TypePtr BuildType() const override;
bool operator==(const AbstractRefTensor &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
/// Clones through AbstractTensor::Clone(), bypassing this class's Clone override.
AbstractBasePtr CloneAsTensor() const { return AbstractTensor::Clone(); }
std::string ToString() const override;
/// Returns this same object viewed as an AbstractTensor.
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
inline ValuePtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Broaden() const override;
/// Overload for joining with another ref tensor; the generic Join below takes any abstract.
virtual AbstractBasePtr Join(const std::shared_ptr<AbstractRefTensor> &other);
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr PartialBroaden() const override;
private:
// ref_key_value is the reference key of AbstractRef, the value can be a string value or kAnyValue
ValuePtr ref_key_value_;
};
using AbstractRefPtr = std::shared_ptr<AbstractRefTensor>;
/// Hash of a list of abstracts (defined out of line).
MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
/// Element-wise deep comparison of two abstract lists (defined out of line).
MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
/// Function object wrapping AbstractBasePtrListHash, e.g. as the Hash parameter of an unordered container.
struct AbstractBasePtrListHasher {
std::size_t operator()(const AbstractBasePtrList &args_spec_list) const {
return AbstractBasePtrListHash(args_spec_list);
}
};
/// Function object wrapping AbstractBasePtrListDeepEqual, e.g. as the KeyEqual parameter of an unordered container.
struct AbstractBasePtrListEqual {
bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
return AbstractBasePtrListDeepEqual(lhs, rhs);
}
};
/// \brief Common base for sparse tensors represented as a tuple of component abstracts.
class MS_CORE_API AbstractSparseTensor : public AbstractTuple {
public:
explicit AbstractSparseTensor(AbstractBasePtrList &&elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractTuple(std::move(elements), tuple_nodes) {}
explicit AbstractSparseTensor(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractTuple(elements, tuple_nodes) {}
~AbstractSparseTensor() override = default;
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractTuple)
/// Component at `index`, presumably cast to T (defined out of line).
template <typename T>
const T GetAbsPtrAt(size_t index) const;
BaseShapePtrList ElementsShapeTupleRecursive() const;
TypePtr BuildType() const override;
/// Shape is a TupleShape built from ElementsShapeTupleRecursive().
BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShapeTupleRecursive()); }
const TypeId GetTensorTypeIdAt(size_t index) const;
const TypeId GetShapeTypeIdAt(size_t index) const;
const AbstractTuplePtr shape() const;
};
using AbstractSparseTensorPtr = std::shared_ptr<AbstractSparseTensor>;
/// \brief Abstract for a row tensor: an undetermined element/shape plus optional indices,
/// values and dense_shape component abstracts (stored as given by the setters).
class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined {
public:
explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape)
: AbstractUndetermined(element_type, shape) {}
~AbstractRowTensor() override = default;
MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
private:
/// Out-of-line helper that builds a row-tensor abstract with the given shape.
std::shared_ptr<AbstractRowTensor> MakeAbstract(const BaseShapePtr &shp) const;
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
};
using AbstractRowTensorPtr = std::shared_ptr<AbstractRowTensor>;
// COOTensor is a Tuple with a fixed number of elements and a specific meaning for each position:
// kIndicesIdx (0) holds the indices and kValuesIdx (1) holds the values.
class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor {
public:
explicit AbstractCOOTensor(AbstractBasePtrList &&elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSparseTensor(std::move(elements), tuple_nodes) {}
explicit AbstractCOOTensor(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSparseTensor(elements, tuple_nodes) {}
~AbstractCOOTensor() override = default;
MS_DECLARE_PARENT(AbstractCOOTensor, AbstractSparseTensor)
/// Component accessors (defined out of line), presumably reading the positions named below.
const AbstractTensorPtr indices() const;
const AbstractTensorPtr values() const;
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr PartialBroaden() const override;
std::string ToString() const override;
static constexpr size_t kIndicesIdx = 0;
static constexpr size_t kValuesIdx = 1;
};
using AbstractCOOTensorPtr = std::shared_ptr<AbstractCOOTensor>;
// CSRTensor is a Tuple with a fixed number of elements and a specific meaning for each position:
// kIndptrIdx (0) holds indptr, kIndicesIdx (1) the indices and kValuesIdx (2) the values.
class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor {
public:
explicit AbstractCSRTensor(AbstractBasePtrList &&elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSparseTensor(std::move(elements), tuple_nodes) {}
explicit AbstractCSRTensor(const AbstractBasePtrList &elements,
const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
: AbstractSparseTensor(elements, tuple_nodes) {}
~AbstractCSRTensor() override = default;
MS_DECLARE_PARENT(AbstractCSRTensor, AbstractSparseTensor)
/// Component accessors (defined out of line), presumably reading the positions named below.
const AbstractTensorPtr indptr() const;
const AbstractTensorPtr indices() const;
const AbstractTensorPtr values() const;
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr PartialBroaden() const override;
std::string ToString() const override;
static constexpr size_t kIndptrIdx = 0;
static constexpr size_t kIndicesIdx = 1;
static constexpr size_t kValuesIdx = 2;
};
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;
/// \brief Shared base of the monad abstracts.
///
/// Holds a monad value and its monad type; only subclasses can construct it (protected ctor),
/// and each subclass must provide Join().
class MS_CORE_API AbstractMonad : public AbstractBase {
public:
~AbstractMonad() override = default;
MS_DECLARE_PARENT(AbstractMonad, AbstractBase)
/// Depends only on the dynamic type id, so all monads of one kind hash equally.
std::size_t hash() const override { return hash_combine({tid()}); }
/// The type track is returned unchanged.
TypePtr BuildType() const override { return GetTypeTrack(); }
AbstractBasePtr Broaden() const override { return AbstractBase::Broaden(); }
AbstractBasePtr Join(const AbstractBasePtr &other) override = 0;
/// Renders as "<TypeName>(<value>)".
std::string ToString() const override {
const std::string value_text = GetValueTrack()->ToString();
std::string text = type_name();
text += "(";
text += value_text;
text += ")";
return text;
}
protected:
AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
};
using AbstractMonadPtr = std::shared_ptr<AbstractMonad>;
/// \brief U-monad abstract: value defaults to kUMonad, type track is kUMonadType.
class MS_CORE_API AbstractUMonad final : public AbstractMonad {
public:
explicit AbstractUMonad(const ValuePtr &value = kUMonad) : AbstractMonad(value, kUMonadType) {}
~AbstractUMonad() override = default;
MS_DECLARE_PARENT(AbstractUMonad, AbstractMonad)
/// Shares the same value track.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractUMonad>(GetValueTrack()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractUMonad &) const;
bool operator==(const AbstractBase &other) const override;
};
using AbstractUMonadPtr = std::shared_ptr<AbstractUMonad>;
/// \brief IO-monad abstract: value defaults to kIOMonad, type track is kIOMonadType.
class MS_CORE_API AbstractIOMonad final : public AbstractMonad {
public:
explicit AbstractIOMonad(const ValuePtr &value = kIOMonad) : AbstractMonad(value, kIOMonadType) {}
~AbstractIOMonad() override = default;
MS_DECLARE_PARENT(AbstractIOMonad, AbstractMonad)
/// Shares the same value track.
AbstractBasePtr Clone() const override { return std::make_shared<AbstractIOMonad>(GetValueTrack()); }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractIOMonad &) const;
bool operator==(const AbstractBase &other) const override;
};
using AbstractIOMonadPtr = std::shared_ptr<AbstractIOMonad>;
/// Extracts logging information from `info` (defined out of line).
MS_CORE_API std::string ExtractLoggingInfo(const std::string &info);
/// Synchronizes element "use" flags between two sequence abstracts, recursing into nested
/// sequences (defined out of line; NOTE(review): direction of the sync to be confirmed).
MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
const AbstractSequencePtr &rhs_sequence);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_