Program Listing for File abstract_value.h

Return to documentation for file (include/converter/include/core/abstract/abstract_value.h)

#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_
#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/log_adapter.h"
#include "utils/hashing.h"
#include "utils/any.h"
#include "utils/hash_map.h"
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "abstract/dshape.h"
#include "abstract/utils.h"
#include "utils/shape_utils.h"

namespace mindspore {
namespace abstract {
class AbstractBase;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
// Root class of the abstract-value hierarchy declared in this header.
// An instance carries three "tracks" (value, type, shape) plus a few bookkeeping fields
// (a value description, a mutability flag and a name).
class MS_CORE_API AbstractBase : public Base {
 public:
  // Callback type stored in trace_node_provider_; it is handed a pointer to an AnfNodePtr.
  using TraceNodeProvider = std::function<void(AnfNodePtr *node)>;

  // Stores the three tracks exactly as given. Nothing is validated here, so a null value
  // (the default) is accepted by the constructor even though set_value() rejects null.
  explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                        const BaseShapePtr &shape = kNoShape)
      : value_(value), type_(type), shape_(shape) {}

  ~AbstractBase() override = default;
  MS_DECLARE_PARENT(AbstractBase, Base)

  // Default hash: the type id only. Subclasses override this to mix in their own state.
  std::size_t hash() const override { return tid(); }

  // String form of this abstract (defined out of line).
  std::string ToString() const override;

  // String form with an explicit verbosity switch (defined out of line).
  virtual std::string ToString(bool verbose) const;

  // Equality against any other abstract (defined out of line).
  virtual bool operator==(const AbstractBase &other) const;

  // Replaces the value track; raises through MS_EXCEPTION_IF_NULL when `value` is null.
  void set_value(const ValuePtr &value) {
    MS_EXCEPTION_IF_NULL(value);
    value_ = value;
  }

  // Replaces the type track; raises through MS_EXCEPTION_IF_NULL when `type` is null.
  void set_type(const TypePtr &type) {
    MS_EXCEPTION_IF_NULL(type);
    type_ = type;
  }

  // Replaces the shape track; raises through MS_EXCEPTION_IF_NULL when `shape` is null.
  // Virtual: some subclasses in this header override it.
  virtual void set_shape(const BaseShapePtr &shape) {
    MS_EXCEPTION_IF_NULL(shape);
    shape_ = shape;
  }

  // Stores a description of the value (used for error reporting, see value_desc_).
  void set_value_desc(const std::string &desc) { value_desc_ = desc; }

  // Returns the stored value description.
  const std::string &value_desc() const { return value_desc_; }

  // Raw accessors for the three tracks; each returns the stored pointer, which may be null.
  const ValuePtr &GetValueTrack() const { return value_; }

  const TypePtr &GetTypeTrack() const { return type_; }

  const BaseShapePtr &GetShapeTrack() const { return shape_; }

  // Builds the value for this abstract (defined out of line).
  ValuePtr BuildValue() const;

  // Builds the type for this abstract; every concrete subclass must provide it.
  virtual TypePtr BuildType() const = 0;

  // Builds the shape for this abstract; the base implementation returns kNoShape.
  virtual BaseShapePtr BuildShape() const { return kNoShape; }

  // Creates a copy of this abstract; every concrete subclass must provide it.
  virtual AbstractBasePtr Clone() const = 0;

  // Installs the process-wide trace node provider callback.
  static void set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
    trace_node_provider_ = trace_node_provider;
  }

  // Process-wide callback installed by set_trace_node_provider().
  static TraceNodeProvider trace_node_provider_;

  // Returns a broadened version of this abstract (defined out of line).
  virtual AbstractBasePtr Broaden() const;

  // Base join: ignores `other` and returns this object.
  virtual AbstractBasePtr Join(const AbstractBasePtr &other) { return shared_from_base<AbstractBase>(); }

  // True when the value track is kAnyValue (compares the stored pointer with kAnyValue).
  bool IsBroaden() const { return value_ == kAnyValue; }

  // Streams the abstract's ToString(). A null pointer is written as "nullptr" instead of
  // being dereferenced, so logging an unset abstract cannot crash.
  friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
    if (a == nullptr) {
      os << "nullptr";
      return os;
    }
    os << a->ToString();
    return os;
  }

  // Returns a partially broadened version of this abstract (defined out of line).
  virtual AbstractBasePtr PartialBroaden() const;

  // Accessors for the value-mutable flag (false by default).
  bool value_mutable() const { return value_mutable_; }

  void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }

  // Optional process-wide hook: a plain function pointer taking a condition abstract and
  // returning a pair of bools. It is null until set_interpret_bool_checker() installs one.
  using InterpretBoolChecker = std::pair<bool, bool> (*)(const AbstractBasePtr &cond);
  static inline InterpretBoolChecker interpret_bool_checker_ = nullptr;
  static void set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
  static inline InterpretBoolChecker interpret_bool_checker() { return interpret_bool_checker_; }

  // Returns a copy of the name.
  std::string name() const { return name_; }

  // Sets the name.
  void set_name(const std::string &name) { name_ = name; }

 protected:
  // Hook for subclasses to produce their value; the base implementation returns kAnyValue.
  virtual ValuePtr RealBuildValue() const { return kAnyValue; }
  std::string name_;

 private:
  ValuePtr value_;
  TypePtr type_;
  BaseShapePtr shape_;
  std::string value_desc_;  // store initial value description for error report
  bool value_mutable_{false};
};

// Abstract of a scalar: a value track paired with a type track.
class MS_CORE_API AbstractScalar final : public AbstractBase {
 public:
  // Fully broadened scalar: kAnyValue with kAnyType.
  AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {}

  // Scalar with an explicit value and type; neither is validated.
  AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}

  // Scalar whose type is taken from the value itself. `value` is dereferenced, so it must not be null.
  explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}

  // Convenience constructors from native C++ values; each wraps the value with MakeValue()
  // and pairs it with the matching type constant.
  explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}

  explicit AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}

  explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {}

  explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {}

  explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {}

  explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {}

  // Scalar of a known type whose value is kAnyValue.
  explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {}

  ~AbstractScalar() override = default;
  MS_DECLARE_PARENT(AbstractScalar, AbstractBase)

  // Combines the type id with the hashes of the value and type tracks.
  // Both tracks are dereferenced, so they must be non-null when this is called.
  std::size_t hash() const override {
    const auto &value_track = GetValueTrack();
    const auto &type_track = GetTypeTrack();
    return hash_combine({tid(), value_track->hash(), type_track->hash()});
  }

  // The type of a scalar is simply its type track.
  TypePtr BuildType() const override { return GetTypeTrack(); }

  // New scalar that reuses the value pointer and holds a clone of the type.
  AbstractBasePtr Clone() const override {
    auto cloned_type = GetTypeTrack()->Clone();
    return std::make_shared<AbstractScalar>(GetValueTrack(), cloned_type);
  }

  AbstractBasePtr Broaden() const override;

  AbstractBasePtr Join(const AbstractBasePtr &other) override;
};
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;

// Abstract whose value track is itself a type object; its type track is kTypeType.
class MS_CORE_API AbstractType final : public AbstractBase {
 public:
  // Stores `type` as the value track. Raises via MS_LOG(EXCEPTION) when `type` is null.
  explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) {
    if (!type) {
      MS_LOG(EXCEPTION) << "type is nullptr";
    }
  }

  ~AbstractType() override = default;
  MS_DECLARE_PARENT(AbstractType, AbstractBase)

  std::string ToString() const override;

  bool operator==(const AbstractBase &other) const override;

  // Always a freshly allocated TypeType.
  TypePtr BuildType() const override {
    auto type_of_type = std::make_shared<TypeType>();
    return type_of_type;
  }

  AbstractBasePtr Clone() const override;

  // Broadening a type abstract is just a clone.
  AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractTypePtr = std::shared_ptr<AbstractType>;

// Abstract representing an error: the value track holds the error string and
// node_ remembers the node that was specialized to this error.
class MS_CORE_API AbstractError final : public AbstractBase {
 public:
  // Raises via MS_LOG(EXCEPTION) when either argument is null.
  AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
    const bool has_null_argument = (err == nullptr) || (node == nullptr);
    if (has_null_argument) {
      MS_LOG(EXCEPTION) << "err or node is nullptr";
    }
  }

  ~AbstractError() override = default;
  MS_DECLARE_PARENT(AbstractError, AbstractBase)

  // Always a freshly allocated Problem type.
  TypePtr BuildType() const override { return std::make_shared<Problem>(); }

  // Broadening an error is just a clone.
  AbstractBasePtr Broaden() const override { return Clone(); }

  // Rebuilds the error from the value track (cast back to StringImmPtr) and the same node.
  AbstractBasePtr Clone() const override {
    const auto err = GetValueTrack()->cast<StringImmPtr>();
    return std::make_shared<AbstractError>(err, node_);
  }

  std::string ToString() const override;

 private:
  // Origin node been specialized to AbstractError, for debug purpose only.
  const AnfNodePtr node_;
};

// Abstract for a script: a value track paired with a type track (kString when only a value is given).
class MS_CORE_API AbstractScript final : public AbstractBase {
 public:
  // kAnyValue with kAnyType.
  AbstractScript() : AbstractBase(kAnyValue, kAnyType) {}

  // Explicit value and type; neither is validated.
  AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}

  // Value with the type fixed to kString.
  explicit AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}

  ~AbstractScript() override = default;
  MS_DECLARE_PARENT(AbstractScript, AbstractBase)

  // Combines the type id with the hashes of the value and type tracks.
  // Both tracks are dereferenced, so they must be non-null when this is called.
  std::size_t hash() const override {
    const auto &value_track = GetValueTrack();
    const auto &type_track = GetTypeTrack();
    return hash_combine({tid(), value_track->hash(), type_track->hash()});
  }

  // The type is simply the type track.
  TypePtr BuildType() const override { return GetTypeTrack(); }

  // New script abstract that reuses the value pointer and holds a clone of the type.
  AbstractBasePtr Clone() const override {
    auto cloned_type = GetTypeTrack()->Clone();
    return std::make_shared<AbstractScript>(GetValueTrack(), cloned_type);
  }

  // Broadening is just a clone.
  AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractScriptPtr = std::shared_ptr<AbstractScript>;

class Evaluator;
using EvaluatorPtr = std::shared_ptr<Evaluator>;
class AnalysisEngine;
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;

class AbstractFunction;
using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
class AbstractFuncAtom;
using AbstractFuncAtomPtr = std::shared_ptr<AbstractFuncAtom>;
using AbstractFuncAtomPtrList = std::vector<AbstractFuncAtomPtr>;

// Interface shared by all function abstracts. Most operations are pure virtual and
// supplied by subclasses that are not part of this header.
class MS_CORE_API AbstractFunction : public AbstractBase {
 public:
  AbstractFunction() = default;
  ~AbstractFunction() override = default;
  MS_DECLARE_PARENT(AbstractFunction, AbstractBase)

  // Subclass-provided accessor for the unique function abstract.
  virtual AbstractFunctionPtr GetUnique() = 0;

  // Always a freshly allocated Function type.
  TypePtr BuildType() const override { return std::make_shared<Function>(); }

  // Cloning delegates to Copy().
  AbstractBasePtr Clone() const override { return Copy(); }

  // Broadening returns this very object; constness is cast away only to obtain the shared pointer.
  AbstractBasePtr Broaden() const override {
    auto *self = const_cast<AbstractFunction *>(this);
    return self->shared_from_base<AbstractFunction>();
  }

  // Subclass-provided copy.
  virtual AbstractFunctionPtr Copy() const = 0;

  // Join against a generic abstract (defined out of line, not overridable further).
  AbstractBasePtr Join(const AbstractBasePtr &other) final;

  // Join against another function abstract; supplied by subclasses.
  virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0;

  // Invokes the given callback on function atoms; traversal is supplied by subclasses.
  virtual void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const = 0;

  // Equality against a generic abstract (defined out of line, not overridable further).
  bool operator==(const AbstractBase &other) const final;

  // Equality against another function abstract; supplied by subclasses.
  virtual bool operator==(const AbstractFunction &other) const = 0;

  // Factory building a function abstract from a list of atoms (defined out of line).
  static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list);

  // Tracking id of this function; 0 in the base implementation.
  virtual std::uintptr_t tracking_id() const { return 0; }

  // Copy without the tracking id; the base implementation is a plain Copy().
  virtual AbstractFunctionPtr CopyWithoutTrackingId() const { return Copy(); }

  // Analysis context of this function; null in the base implementation.
  virtual AnalysisContextPtr context() const { return nullptr; }

  // Derives a tracking id from a node: the node's raw address reinterpreted as an integer.
  static std::uintptr_t ToTrackingId(const AnfNodePtr &node) {
    auto *raw_node = node.get();
    return reinterpret_cast<std::uintptr_t>(raw_node);
  }
};

using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>;

// Abstract of a keyword argument: a name paired with the abstract of its value.
class MS_CORE_API AbstractKeywordArg final : public AbstractBase {
 public:
  // Stores the key and the argument abstract as given; the argument is not checked for null.
  AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {}

  ~AbstractKeywordArg() override = default;
  MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase)

  TypePtr BuildType() const override;

  AbstractBasePtr Clone() const override;

  AbstractBasePtr Broaden() const override;

  std::size_t hash() const override;

  // Equality against another keyword argument (defined out of line).
  bool operator==(const AbstractKeywordArg &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  // Returns a copy of the argument name.
  std::string get_key() const { return arg_name_; }

  // Returns the abstract of the argument value.
  AbstractBasePtr get_arg() const { return arg_value_; }

  std::string ToString() const override;

 protected:
  ValuePtr RealBuildValue() const override;

 private:
  std::string arg_name_;
  AbstractBasePtr arg_value_;
};
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;

// Abstract with value kAnyValue that is described by an element abstract plus a shape.
// It is the base class of the tensor-like abstracts in this header.
class MS_CORE_API AbstractUndetermined : public AbstractBase {
 public:
  // No element and the base-class default shape.
  AbstractUndetermined() : AbstractBase(kAnyValue) {}

  // Builds from an element abstract and a shape.
  // Raises when `element` is null, when `element` is itself an AbstractUndetermined,
  // when `shape` is null, or when `shape` is a NoShape.
  explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractBase(kAnyValue), element_(element) {
    if (!element) {
      MS_LOG(EXCEPTION) << "element is nullptr";
    }
    const bool element_is_undetermined = element->isa<AbstractUndetermined>();
    if (element_is_undetermined) {
      MS_LOG(EXCEPTION) << "element type error";
    }
    MS_EXCEPTION_IF_NULL(shape);
    const bool shape_is_no_shape = shape->isa<NoShape>();
    if (shape_is_no_shape) {
      MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
    }
    AbstractBase::set_shape(shape);
  }

  // Builds from an element type and a shape vector; the element becomes an
  // AbstractScalar(kAnyValue, element_type). Raises when `element_type` is null.
  AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape)
      : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
    if (!element_type) {
      MS_LOG(EXCEPTION) << "element_type is nullptr";
    }
    const auto shape_from_vector = std::make_shared<Shape>(shape);
    AbstractBase::set_shape(shape_from_vector);
  }

  // Builds from an element type and a shape object; the element becomes an
  // AbstractScalar(kAnyValue, element_type).
  // Raises when `element_type` is null, when `shape` is null, or when `shape` is a NoShape.
  explicit AbstractUndetermined(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
    if (!element_type) {
      MS_LOG(EXCEPTION) << "element_type is nullptr";
    }
    MS_EXCEPTION_IF_NULL(shape);
    const bool shape_is_no_shape = shape->isa<NoShape>();
    if (shape_is_no_shape) {
      MS_LOG(EXCEPTION) << "AbstractUndetermined can't set shape as NoShape.";
    }
    AbstractBase::set_shape(shape);
  }

  ~AbstractUndetermined() override = default;
  MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)

  // Always a freshly allocated UndeterminedType.
  TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }

  // Returns a default-constructed AbstractUndetermined.
  // NOTE(review): the element and shape of this object are not carried over - confirm this is intended.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }

  // Returns the element abstract (null for a default-constructed object).
  AbstractBasePtr element() const { return element_; }

  // Shape accessor (defined out of line).
  ShapePtr shape() const;

  // Shape setter (defined out of line); overrides the base-class setter.
  void set_shape(const BaseShapePtr &shape) override;

 protected:
  AbstractBasePtr element_;
};

// Abstract of a tensor: an AbstractUndetermined that additionally tracks an optional
// value range (min/max) and an optional shape value.
class MS_CORE_API AbstractTensor : public AbstractUndetermined {
 public:
  // The constructors forward to the matching AbstractUndetermined constructors.
  explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractUndetermined(element, shape) {}

  AbstractTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {}

  // Builds from a concrete tensor's dtype and shape. `tensor` is dereferenced, so it must not be null.
  explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}

  explicit AbstractTensor(const TypePtr &element_type, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractUndetermined(element_type, shape) {}

  ~AbstractTensor() override = default;
  MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)

  // Records the value range; both bounds are stored as given (null is allowed).
  void set_value_range(const ValuePtr &min_value, const ValuePtr &max_value) {
    min_value_ = min_value;
    max_value_ = max_value;
  }

  // Lower bound of the value range, or null when none was set.
  const ValuePtr &get_min_value() const { return min_value_; }

  // Upper bound of the value range, or null when none was set.
  const ValuePtr &get_max_value() const { return max_value_; }

  // Records the shape value (stored as given).
  void set_shape_value(const ValuePtr &shape_value) { shape_value_ = shape_value; }

  // Shape value, or null when none was set.
  const ValuePtr &get_shape_value() const { return shape_value_; }

  TypePtr BuildType() const override;

  BaseShapePtr BuildShape() const override;

  AbstractBasePtr Clone() const override;

  AbstractBasePtr Broaden() const override;

  AbstractBasePtr BroadenWithShape() const;

  AbstractBasePtr Join(const AbstractBasePtr &other) override;

  // Equality against another tensor abstract (defined out of line).
  virtual bool operator==(const AbstractTensor &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  std::string ToString() const override;

  // Hash of the type id, the element and (when present) the shape track.
  // element_ is dereferenced, so it must be non-null when this is called.
  std::size_t hash() const override {
    // The value pointer is deliberately left out of the hash: CSE (Common Subexpression
    // Elimination) relies on this hash to find duplicate ValueNodes whose Tensor values are equal.
    auto combined = hash_combine(tid(), element_->hash());
    if (const auto &shape_track = GetShapeTrack(); shape_track != nullptr) {
      combined = hash_combine(combined, shape_track->hash());
    }
    return combined;
  }

  AbstractBasePtr PartialBroaden() const override;

 protected:
  bool equal_to(const AbstractTensor &other) const;
  ValuePtr min_value_ = nullptr;
  ValuePtr max_value_ = nullptr;
  ValuePtr shape_value_ = nullptr;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;

// Common base of the sequence abstracts (tuple, list): an ordered list of element
// abstracts plus an optional list of weak references to associated sequence nodes.
class MS_CORE_API AbstractSequence : public AbstractBase {
 public:
  // Builds from a list of elements (moved in) and the associated sequence nodes.
  explicit AbstractSequence(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes);

  // Builds from a list of elements (copied) and the associated sequence nodes.
  explicit AbstractSequence(const AbstractBasePtrList &elements,
                            const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes);

  ~AbstractSequence() override = default;
  MS_DECLARE_PARENT(AbstractSequence, AbstractBase)


  // Per-element helpers (all defined out of line). Judging by their use in the
  // subclasses below, each returns one result per element.
  TypePtrList ElementsType() const;

  BaseShapePtrList ElementsShape() const;

  AbstractBasePtrList ElementsClone() const;

  AbstractBasePtrList ElementsBroaden() const;

  AbstractBasePtrList ElementsPartialBroaden() const;

  // Builds a value of type T from the elements (defined out of line).
  template <typename T>
  ValuePtr ElementsBuildValue() const;

  // Joins the elements with those of `other`; T names the sequence kind (defined out of line).
  template <typename T>
  AbstractBasePtr ElementsJoin(const AbstractBasePtr &other);

  // Joins the sequence nodes with those of `other` (defined out of line).
  AnfNodeWeakPtrList SequenceNodesJoin(const AbstractBasePtr &other);

  // Number of elements / whether there are none / read-only access to the elements.
  std::size_t size() const { return elements_.size(); }
  bool empty() const { return elements_.empty(); }
  const AbstractBasePtrList &elements() const { return elements_; }

  bool PurifyElements();

  // Returns the shared list of sequence nodes (may be null).
  const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes() const { return sequence_nodes_; }

  // Replaces the shared list of sequence nodes.
  void set_sequence_nodes(const std::shared_ptr<AnfNodeWeakPtrList> &sequence_nodes) {
    sequence_nodes_ = sequence_nodes;
  }

  void InsertSequenceNode(const AnfNodePtr &sequence_node);

  void InsertSequenceNodes(const AnfNodeWeakPtrList &sequence_nodes);

  void UpdateSequenceNode(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node);

  std::size_t hash() const override;

  std::string ToStringInternal() const;
  std::string ToString() const override;
  std::string ToString(bool verbose) const override;

  // Element access by index (defined out of line).
  const AbstractBasePtr operator[](const std::size_t &dim) const;

  // Equality against another sequence (defined out of line).
  virtual bool operator==(const AbstractSequence &other) const;

 protected:
  AbstractBasePtrList elements_;
  // Since there're not too many nodes, we just use vector here.
  std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes_;
};
using AbstractSequencePtr = std::shared_ptr<AbstractSequence>;

// Tuple flavour of AbstractSequence.
class MS_CORE_API AbstractTuple : public AbstractSequence {
 public:
  // Takes ownership of `elements`; `tuple_nodes` may be omitted.
  explicit AbstractTuple(AbstractBasePtrList &&elements,
                         const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSequence(std::move(elements), tuple_nodes) {}

  // Copies `elements`; `tuple_nodes` may be omitted.
  explicit AbstractTuple(const AbstractBasePtrList &elements,
                         const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSequence(elements, tuple_nodes) {}

  ~AbstractTuple() override = default;
  MS_DECLARE_PARENT(AbstractTuple, AbstractSequence)

  // Shape setter (defined out of line).
  void set_shape(const BaseShapePtr &shape) override;

  // Tuple type assembled from the element types.
  TypePtr BuildType() const override { return std::make_shared<Tuple>(ElementsType()); }

  // Tuple shape assembled from the element shapes.
  BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShape()); }

  // New tuple over cloned elements, sharing the same sequence-node list.
  AbstractBasePtr Clone() const override {
    auto cloned_elements = ElementsClone();
    return std::make_shared<AbstractTuple>(std::move(cloned_elements), sequence_nodes());
  }

  // New tuple over broadened elements, sharing the same sequence-node list.
  AbstractBasePtr Broaden() const override {
    auto broadened_elements = ElementsBroaden();
    return std::make_shared<AbstractTuple>(std::move(broadened_elements), sequence_nodes());
  }

  // New tuple over partially broadened elements, sharing the same sequence-node list.
  AbstractBasePtr PartialBroaden() const override {
    auto broadened_elements = ElementsPartialBroaden();
    return std::make_shared<AbstractTuple>(std::move(broadened_elements), sequence_nodes());
  }

  // Joins element-wise with `other`, then adds the joined sequence nodes to the result.
  // Raises when the element join does not yield a sequence.
  AbstractBasePtr Join(const AbstractBasePtr &other) override {
    const AbstractBasePtr joined_elements = ElementsJoin<AbstractTuple>(other);
    auto res = dyn_cast<AbstractSequence>(joined_elements);
    MS_EXCEPTION_IF_NULL(res);
    const auto joined_nodes = SequenceNodesJoin(other);
    res->InsertSequenceNodes(joined_nodes);
    return res;
  }

  bool ContainsAllBroadenTensors() const;

  // Equality against another tuple (defined out of line).
  bool operator==(const AbstractTuple &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

 protected:
  // The value of a tuple is a ValueTuple built from its elements.
  ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueTuple>(); }
};
using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;

// List flavour of AbstractSequence.
class MS_CORE_API AbstractList final : public AbstractSequence {
 public:
  // Takes ownership of `elements`; `list_nodes` may be omitted.
  explicit AbstractList(AbstractBasePtrList &&elements, const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr)
      : AbstractSequence(std::move(elements), list_nodes) {}

  // Copies `elements`; `list_nodes` may be omitted.
  explicit AbstractList(const AbstractBasePtrList &elements,
                        const std::shared_ptr<AnfNodeWeakPtrList> &list_nodes = nullptr)
      : AbstractSequence(elements, list_nodes) {}

  ~AbstractList() override = default;
  MS_DECLARE_PARENT(AbstractList, AbstractSequence)

  // List type assembled from the element types.
  TypePtr BuildType() const override { return std::make_shared<List>(ElementsType()); }

  // List shape assembled from the element shapes.
  BaseShapePtr BuildShape() const override { return std::make_shared<ListShape>(ElementsShape()); }

  // New list over cloned elements, sharing the same sequence-node list.
  AbstractBasePtr Clone() const override {
    auto cloned_elements = ElementsClone();
    return std::make_shared<AbstractList>(std::move(cloned_elements), sequence_nodes());
  }

  // New list over broadened elements, sharing the same sequence-node list.
  AbstractBasePtr Broaden() const override {
    auto broadened_elements = ElementsBroaden();
    return std::make_shared<AbstractList>(std::move(broadened_elements), sequence_nodes());
  }

  // New list over partially broadened elements, sharing the same sequence-node list.
  AbstractBasePtr PartialBroaden() const override {
    auto broadened_elements = ElementsPartialBroaden();
    return std::make_shared<AbstractList>(std::move(broadened_elements), sequence_nodes());
  }

  // Joins element-wise with `other`, then adds the joined sequence nodes to the result.
  // Raises when the element join does not yield a sequence.
  AbstractBasePtr Join(const AbstractBasePtr &other) override {
    const AbstractBasePtr joined_elements = ElementsJoin<AbstractList>(other);
    auto res = dyn_cast<AbstractSequence>(joined_elements);
    MS_EXCEPTION_IF_NULL(res);
    const auto joined_nodes = SequenceNodesJoin(other);
    res->InsertSequenceNodes(joined_nodes);
    return res;
  }

  // Equality against another list (defined out of line).
  bool operator==(const AbstractList &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

 protected:
  // The value of a list is a ValueList built from its elements.
  ValuePtr RealBuildValue() const override { return ElementsBuildValue<ValueList>(); }
};
using AbstractListPtr = std::shared_ptr<AbstractList>;

// Abstract of a dictionary, stored as a vector of AbstractAttribute entries.
class MS_CORE_API AbstractDictionary final : public AbstractBase {
 public:
  // Copies the given entries.
  explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {}

  ~AbstractDictionary() override = default;
  MS_DECLARE_PARENT(AbstractDictionary, AbstractBase)

  TypePtr BuildType() const override;

  // Equality against another dictionary (defined out of line).
  bool operator==(const AbstractDictionary &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  AbstractBasePtr Clone() const override;

  AbstractBasePtr Broaden() const override;

  std::string ToString() const override;

  std::size_t hash() const override;

  // Number of entries.
  std::size_t size() const { return key_values_.size(); }

  // Read-only access to the entries.
  const std::vector<AbstractAttribute> &elements() const { return key_values_; }

 protected:
  ValuePtr RealBuildValue() const override;
  std::vector<AbstractAttribute> key_values_;
};
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;

// Abstract of a slice, made of three component abstracts: start, stop and step.
class MS_CORE_API AbstractSlice final : public AbstractBase {
 public:
  // Stores the three components as given; none is checked for null.
  AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step)
      : start_(start), stop_(stop), step_(step) {}

  ~AbstractSlice() override = default;
  MS_DECLARE_PARENT(AbstractSlice, AbstractBase)

  TypePtr BuildType() const override;

  // Equality against another slice (defined out of line).
  bool operator==(const AbstractSlice &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  AbstractBasePtr Clone() const override;

  AbstractBasePtr Broaden() const override;

  std::string ToString() const override;

  std::size_t hash() const override;

  // Component accessors; each returns the stored pointer.
  AbstractBasePtr start() const { return start_; }

  AbstractBasePtr stop() const { return stop_; }

  AbstractBasePtr step() const { return step_; }

 protected:
  ValuePtr RealBuildValue() const override;

 private:
  AbstractBasePtr start_;
  AbstractBasePtr stop_;
  AbstractBasePtr step_;
};
using AbstractSlicePtr = std::shared_ptr<AbstractSlice>;

// Abstract that wraps ("tags") another abstract.
// Clone(), Broaden() and hash() dereference the wrapped element, so it is expected to be
// non-null; the constructor itself does not check this.
class MS_CORE_API AbstractJTagged final : public AbstractBase {
 public:
  // Stores the wrapped abstract as given.
  explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {}

  ~AbstractJTagged() override = default;
  MS_DECLARE_PARENT(AbstractJTagged, AbstractBase)

  TypePtr BuildType() const override;

  // New tag around a clone of the wrapped abstract.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }

  // New tag around the broadened wrapped abstract.
  AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }

  AbstractBasePtr Join(const AbstractBasePtr &other) override;

  // Equality against another tagged abstract (defined out of line).
  bool operator==(const AbstractJTagged &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  std::string ToString() const override;

  // Returns the wrapped abstract. Const-qualified, like the other read-only getters in this
  // header, so it can also be called through a const object or pointer-to-const.
  AbstractBasePtr element() const { return element_; }

  // Combines the type id with the hash of the wrapped abstract.
  std::size_t hash() const override { return hash_combine(tid(), element_->hash()); }

 private:
  AbstractBasePtr element_;
};
using AbstractJTaggedPtr = std::shared_ptr<AbstractJTagged>;

// Abstract for "none": constructed with the base-class defaults, then typed as TypeNone.
class MS_CORE_API AbstractNone final : public AbstractBase {
 public:
  // Uses the default base construction and then replaces the type track with a new TypeNone.
  AbstractNone() : AbstractBase() { set_type(std::make_shared<TypeNone>()); }

  ~AbstractNone() override = default;
  MS_DECLARE_PARENT(AbstractNone, AbstractBase)

  // Always a freshly allocated TypeNone.
  TypePtr BuildType() const override { return std::make_shared<TypeNone>(); }

  // Equality against another AbstractNone (defined out of line).
  bool operator==(const AbstractNone &) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  // Cloning creates a new default-constructed AbstractNone.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractNone>(); }

  std::string ToString() const override;

 protected:
  ValuePtr RealBuildValue() const override;
};
using AbstractNonePtr = std::shared_ptr<AbstractNone>;

// Abstract for "null": value track kNull, type track TypeNull.
class MS_CORE_API AbstractNull final : public AbstractBase {
 public:
  // Sets the value track to kNull and the type track to a new TypeNull.
  AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }

  ~AbstractNull() override = default;
  MS_DECLARE_PARENT(AbstractNull, AbstractBase)

  // Always a freshly allocated TypeNull.
  TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }

  // Equality against another AbstractNull (defined out of line).
  bool operator==(const AbstractNull &) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  // Cloning creates a new default-constructed AbstractNull.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractNull>(); }

  std::string ToString() const override;
};
using AbstractNullPtr = std::shared_ptr<AbstractNull>;

// Abstract for a time-out marker.
// NOTE(review): it is built from the same kNull value and TypeNull type as AbstractNull;
// the two are distinguished only by their class - confirm this is intended.
class MS_CORE_API AbstractTimeOut final : public AbstractBase {
 public:
  // Sets the value track to kNull and the type track to a new TypeNull.
  AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }

  ~AbstractTimeOut() override = default;
  MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase)

  // Always a freshly allocated TypeNull.
  TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }

  // Equality against another AbstractTimeOut (defined out of line).
  bool operator==(const AbstractTimeOut &) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  // Cloning creates a new default-constructed AbstractTimeOut.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractTimeOut>(); }

  std::string ToString() const override;
};
using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>;

// Abstract for an ellipsis: value track kEllipsis, type track TypeEllipsis.
class MS_CORE_API AbstractEllipsis final : public AbstractBase {
 public:
  // Sets the value track to kEllipsis and the type track to a new TypeEllipsis.
  AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }

  ~AbstractEllipsis() override = default;
  MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)

  // Always a freshly allocated TypeEllipsis.
  TypePtr BuildType() const override { return std::make_shared<TypeEllipsis>(); }

  // Equality against another AbstractEllipsis (defined out of line).
  bool operator==(const AbstractEllipsis &) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  // Cloning creates a new default-constructed AbstractEllipsis.
  AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }

  std::string ToString() const override;
};
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;

// Tensor abstract that additionally carries a reference key value.
class MS_CORE_API AbstractRefTensor final : public AbstractTensor {
 public:
  // Builds from a tensor abstract and the reference key value (defined out of line).
  AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value);

  ~AbstractRefTensor() override = default;
  MS_DECLARE_PARENT(AbstractRefTensor, AbstractTensor)

  TypePtr BuildType() const override;

  // Equality against another ref tensor (defined out of line).
  bool operator==(const AbstractRefTensor &other) const;

  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;

  AbstractBasePtr Clone() const override;

  // Clones using the AbstractTensor implementation rather than this class's override.
  AbstractBasePtr CloneAsTensor() const { return AbstractTensor::Clone(); }

  std::string ToString() const override;

  // Returns this object viewed as an AbstractTensor pointer.
  inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }

  // Returns the reference key value.
  inline ValuePtr ref_key_value() const { return ref_key_value_; }

  AbstractBasePtr Broaden() const override;

  // Join against another ref tensor (defined out of line).
  // NOTE(review): declared virtual although the class is final.
  virtual AbstractBasePtr Join(const std::shared_ptr<AbstractRefTensor> &other);
  // Join against a generic abstract (defined out of line).
  AbstractBasePtr Join(const AbstractBasePtr &other) override;

  AbstractBasePtr PartialBroaden() const override;

 private:
  // ref_key_value is the reference key of AbstractRef, the value can be a string value or kAnyValue
  ValuePtr ref_key_value_;
};
using AbstractRefPtr = std::shared_ptr<AbstractRefTensor>;

// Hash of a list of abstracts (defined out of line); used by AbstractBasePtrListHasher.
MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);

// Equality of two lists of abstracts (defined out of line); used by AbstractBasePtrListEqual.
MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);

// Hash functor over AbstractBasePtrList; forwards to AbstractBasePtrListHash().
struct AbstractBasePtrListHasher {
  std::size_t operator()(const AbstractBasePtrList &args_spec_list) const {
    const std::size_t list_hash = AbstractBasePtrListHash(args_spec_list);
    return list_hash;
  }
};

// Equality functor over AbstractBasePtrList; forwards to AbstractBasePtrListDeepEqual().
struct AbstractBasePtrListEqual {
  bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
    const bool lists_equal = AbstractBasePtrListDeepEqual(lhs, rhs);
    return lists_equal;
  }
};

// Base class of the sparse tensor abstracts; structurally an AbstractTuple.
class MS_CORE_API AbstractSparseTensor : public AbstractTuple {
 public:
  // Both constructors forward to the matching AbstractTuple constructor.
  explicit AbstractSparseTensor(AbstractBasePtrList &&elements,
                                const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractTuple(std::move(elements), tuple_nodes) {}

  explicit AbstractSparseTensor(const AbstractBasePtrList &elements,
                                const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractTuple(elements, tuple_nodes) {}

  ~AbstractSparseTensor() override = default;
  MS_DECLARE_PARENT(AbstractSparseTensor, AbstractTuple)

  // Returns the element at `index` as type T (defined out of line).
  template <typename T>
  const T GetAbsPtrAt(size_t index) const;
  // Element shapes used by BuildShape() (defined out of line).
  BaseShapePtrList ElementsShapeTupleRecursive() const;
  TypePtr BuildType() const override;
  // Tuple shape assembled from ElementsShapeTupleRecursive().
  BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShapeTupleRecursive()); }

  // Type id lookups by element index (defined out of line).
  const TypeId GetTensorTypeIdAt(size_t index) const;

  const TypeId GetShapeTypeIdAt(size_t index) const;

  // Shape accessor returning a tuple abstract (defined out of line).
  const AbstractTuplePtr shape() const;
};
using AbstractSparseTensorPtr = std::shared_ptr<AbstractSparseTensor>;

// Abstract of a row tensor: an AbstractUndetermined that also holds abstracts for its
// indices, values and dense shape. Those three members stay null until their setters are called.
class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined {
 public:
  // Both constructors forward to the matching AbstractUndetermined constructor.
  explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractUndetermined(element, shape) {}

  AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape)
      : AbstractUndetermined(element_type, shape) {}

  ~AbstractRowTensor() override = default;
  MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined)


  // Indices abstract: getter and setter (stored as given).
  const AbstractTensorPtr indices() const { return indices_; }

  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }

  // Values abstract: getter and setter (stored as given).
  const AbstractTensorPtr values() const { return values_; }

  void set_values(const AbstractTensorPtr &values) { values_ = values; }

  // Dense-shape abstract: getter and setter (stored as given).
  const AbstractTuplePtr dense_shape() const { return dense_shape_; }

  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }

  TypePtr BuildType() const override;

  AbstractBasePtr Clone() const override;

  AbstractBasePtr Broaden() const override;

  AbstractBasePtr BroadenWithShape() const;

  std::string ToString() const override;

 private:
  // Helper that builds a row tensor abstract for the given shape (defined out of line).
  std::shared_ptr<AbstractRowTensor> MakeAbstract(const BaseShapePtr &shp) const;
  AbstractTensorPtr indices_;
  AbstractTensorPtr values_;
  AbstractTuplePtr dense_shape_;
};
using AbstractRowTensorPtr = std::shared_ptr<AbstractRowTensor>;

// COOTensor is a Tuple with fixed number of elements and specific meaning of each position.
// kIndicesIdx and kValuesIdx below name the positions of the indices and values elements.
class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor {
 public:
  // Both constructors forward to the matching AbstractSparseTensor constructor.
  explicit AbstractCOOTensor(AbstractBasePtrList &&elements,
                             const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSparseTensor(std::move(elements), tuple_nodes) {}

  explicit AbstractCOOTensor(const AbstractBasePtrList &elements,
                             const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSparseTensor(elements, tuple_nodes) {}

  ~AbstractCOOTensor() override = default;
  MS_DECLARE_PARENT(AbstractCOOTensor, AbstractSparseTensor)

  // Accessors for the indices and values abstracts (defined out of line).
  const AbstractTensorPtr indices() const;
  const AbstractTensorPtr values() const;

  TypePtr BuildType() const override;
  AbstractBasePtr Clone() const override;
  AbstractBasePtr Broaden() const override;
  AbstractBasePtr PartialBroaden() const override;
  std::string ToString() const override;

  // Element positions within the underlying tuple.
  static constexpr size_t kIndicesIdx = 0;
  static constexpr size_t kValuesIdx = 1;
};
using AbstractCOOTensorPtr = std::shared_ptr<AbstractCOOTensor>;

// CSRTensor is a Tuple with fixed number of elements and specific meaning of each position.
// kIndptrIdx, kIndicesIdx and kValuesIdx below name the positions of the indptr, indices
// and values elements.
class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor {
 public:
  // Both constructors forward to the matching AbstractSparseTensor constructor.
  explicit AbstractCSRTensor(AbstractBasePtrList &&elements,
                             const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSparseTensor(std::move(elements), tuple_nodes) {}

  explicit AbstractCSRTensor(const AbstractBasePtrList &elements,
                             const std::shared_ptr<AnfNodeWeakPtrList> &tuple_nodes = nullptr)
      : AbstractSparseTensor(elements, tuple_nodes) {}

  ~AbstractCSRTensor() override = default;
  MS_DECLARE_PARENT(AbstractCSRTensor, AbstractSparseTensor)

  // Accessors for the indptr, indices and values abstracts (defined out of line).
  const AbstractTensorPtr indptr() const;
  const AbstractTensorPtr indices() const;
  const AbstractTensorPtr values() const;

  TypePtr BuildType() const override;
  AbstractBasePtr Clone() const override;
  AbstractBasePtr Broaden() const override;
  AbstractBasePtr PartialBroaden() const override;
  std::string ToString() const override;

  // Element positions within the underlying tuple.
  static constexpr size_t kIndptrIdx = 0;
  static constexpr size_t kIndicesIdx = 1;
  static constexpr size_t kValuesIdx = 2;
};
using AbstractCSRTensorPtr = std::shared_ptr<AbstractCSRTensor>;

// Common base of the monad abstracts. It can only be constructed by subclasses,
// which must also supply Join() and Clone().
class MS_CORE_API AbstractMonad : public AbstractBase {
 public:
  ~AbstractMonad() override = default;
  MS_DECLARE_PARENT(AbstractMonad, AbstractBase)

  // The hash depends on the type id only.
  std::size_t hash() const override { return hash_combine({tid()}); }
  // The type is simply the type track.
  TypePtr BuildType() const override { return GetTypeTrack(); }
  // Broadening uses the AbstractBase implementation unchanged.
  AbstractBasePtr Broaden() const override { return AbstractBase::Broaden(); }
  AbstractBasePtr Join(const AbstractBasePtr &other) override = 0;
  // Renders as "<type name>(<value>)". The value track is dereferenced, so it must be non-null.
  std::string ToString() const override {
    std::ostringstream rendered;
    rendered << type_name();
    rendered << "(" << GetValueTrack()->ToString();
    rendered << ")";
    return rendered.str();
  }

 protected:
  // Stores the value and type tracks as given.
  AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
};
using AbstractMonadPtr = std::shared_ptr<AbstractMonad>;

// Monad abstract typed as kUMonadType; its value defaults to kUMonad.
class MS_CORE_API AbstractUMonad final : public AbstractMonad {
 public:
  explicit AbstractUMonad(const ValuePtr &value = kUMonad) : AbstractMonad(value, kUMonadType) {}
  ~AbstractUMonad() override = default;
  MS_DECLARE_PARENT(AbstractUMonad, AbstractMonad)

  // New AbstractUMonad sharing this object's value pointer.
  AbstractBasePtr Clone() const override {
    const ValuePtr &current_value = GetValueTrack();
    return std::make_shared<AbstractUMonad>(current_value);
  }
  AbstractBasePtr Join(const AbstractBasePtr &other) override;
  // Equality against another AbstractUMonad (defined out of line).
  bool operator==(const AbstractUMonad &) const;
  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;
};
using AbstractUMonadPtr = std::shared_ptr<AbstractUMonad>;

// Monad abstract typed as kIOMonadType; its value defaults to kIOMonad.
class MS_CORE_API AbstractIOMonad final : public AbstractMonad {
 public:
  explicit AbstractIOMonad(const ValuePtr &value = kIOMonad) : AbstractMonad(value, kIOMonadType) {}
  ~AbstractIOMonad() override = default;
  MS_DECLARE_PARENT(AbstractIOMonad, AbstractMonad)

  // New AbstractIOMonad sharing this object's value pointer.
  AbstractBasePtr Clone() const override {
    const ValuePtr &current_value = GetValueTrack();
    return std::make_shared<AbstractIOMonad>(current_value);
  }
  AbstractBasePtr Join(const AbstractBasePtr &other) override;
  // Equality against another AbstractIOMonad (defined out of line).
  bool operator==(const AbstractIOMonad &) const;
  // Equality against a generic abstract (defined out of line).
  bool operator==(const AbstractBase &other) const override;
};
using AbstractIOMonadPtr = std::shared_ptr<AbstractIOMonad>;

// Takes an info string and returns a string derived from it (defined out of line).
// NOTE(review): presumably extracts the logging-relevant part of `info` - confirm against the definition.
MS_CORE_API std::string ExtractLoggingInfo(const std::string &info);

// Operates on a pair of sequence abstracts (defined out of line).
// NOTE(review): presumably synchronizes per-element "use flags" between the two sequences,
// recursing into nested sequences - confirm against the definition.
MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
                                                                const AbstractSequencePtr &rhs_sequence);
}  // namespace abstract
}  // namespace mindspore
#endif  // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_