Custom Pass

View Source File

Overview

When you need to change the structure of the computation graph, you can use MindSpore's custom pass feature: write the pass logic, implement and register a custom pass plugin, and let it optimize the structure of the computation graph.

This tutorial provides a simple custom pass case as a demonstration. For more comprehensive examples, please refer to the examples.

Implementing Custom Pass

Implementing a custom pass requires the following steps:

  1. Include the mindspore/include/custom_pass_api.h header file.

  2. Inherit from the PatternToPatternPass class and implement the DefineSrcPattern, DefineDstPattern and CheckMatchedDAG interfaces.

  3. Inherit from the CustomPassPlugin class and implement the GetPluginName, GetAvailablePassNames and CreatePass interfaces.

  4. Register the custom pass plugin by using the EXPORT_CUSTOM_PASS_PLUGIN macro.

Here, we implement a simple AddNegFusionPass and a custom pass plugin that rewrite the pattern Add(x, Neg(y)) into a single Sub(x, y) operator.

// add_neg_fusion_pass.h
// header file of AddNegFusionPass
#ifndef MINDSPORE_CUSTOM_PASS_ADD_NEG_FUSION_PASS_H_
#define MINDSPORE_CUSTOM_PASS_ADD_NEG_FUSION_PASS_H_

#include "mindspore/include/custom_pass_api.h"

namespace mindspore {
namespace opt {
/**
 * @brief Pass that fuses an Add whose second input is a Neg into a single Sub.
 *
 * Rewrites Add(x, Neg(y)) into Sub(x, y), removing the explicit negation.
 * The patterns only name the "Add", "Neg" and "Sub" primitives; whether a given
 * device backend supports them is up to that backend.
 * Inherits from PatternToPatternPass, the base class the custom pass plugin system works with.
 */
class AddNegFusionPass : public PatternToPatternPass {
 public:
  AddNegFusionPass() : PatternToPatternPass("AddNegFusionPass") {}

  /// @brief Declares the sub-graph to match: Add(x, Neg(y)).
  void DefineSrcPattern(SrcPattern *src_pattern) override;
  /// @brief Declares the replacement sub-graph: Sub(x, y).
  void DefineDstPattern(DstPattern *dst_pattern) override;
  /// @brief Final validation of a match; returns true to allow the rewrite.
  bool CheckMatchedDAG(const PatternMap &pattern_map, const FuncGraphPtr &func_graph,
                       const AnfNodePtr &node) const override;

 private:
  /// @brief Build callback that finalizes the Sub node created for the destination pattern
  ///        (copies the Add node's scope and abstract onto it).
  static AnfNodePtr BuildSub(const PatternMap &m, const AnfNodePtr &default_node);
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CUSTOM_PASS_ADD_NEG_FUSION_PASS_H_
// add_neg_fusion_pass.cc
// cpp file of AddNegFusionPass
#include "add_neg_fusion_pass.h"

namespace mindspore {
namespace opt {
/**
 * Declares the sub-graph this pass looks for: Add(x, Neg(y)).
 *
 * @param src_pattern Pattern builder supplied by the framework; must not be null.
 */
void AddNegFusionPass::DefineSrcPattern(SrcPattern *src_pattern) {
  MS_LOG(INFO) << "Defining source pattern for AddNegFusionPass";
  MS_EXCEPTION_IF_NULL(src_pattern);

  // Primitives forming the matched sub-graph (Neg is created first, as in the original chain).
  auto neg_prim = std::make_shared<Primitive>("Neg");
  auto add_prim = std::make_shared<Primitive>("Add");

  // Free inputs x and y, then "neg" = Neg(y) and "add" = Add(x, neg).
  SrcPattern &pattern = *src_pattern;
  pattern.AddVar("x").AddVar("y");
  pattern.AddCNode("neg", {neg_prim, "y"}).AddCNode("add", {add_prim, "x", "neg"});

  MS_LOG(INFO) << "Source pattern defined: Add(x, Neg(y))";
}

/**
 * Build callback for the "sub" node of the destination pattern.
 *
 * @param m Node bindings from the source-pattern match; "add" and "neg" must be bound to CNodes.
 * @param default_node The Sub node the framework created from the destination pattern.
 * @return default_node as a CNode, carrying the Add node's scope and a clone of its abstract.
 * @throws MindSpore exception (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) if any handle is
 *         null, a node is not a CNode, or the Add node has no abstract.
 */
AnfNodePtr AddNegFusionPass::BuildSub(const PatternMap &m, const AnfNodePtr &default_node) {
  // Check the raw handles before calling cast() on them: cast() through a null pointer is
  // undefined behaviour, so checking only the cast result would be too late.
  MS_EXCEPTION_IF_NULL(default_node);
  auto add_anf = m.Get("add");
  auto neg_anf = m.Get("neg");
  MS_EXCEPTION_IF_NULL(add_anf);
  MS_EXCEPTION_IF_NULL(neg_anf);

  auto add_node = add_anf->cast<CNodePtr>();
  auto neg_node = neg_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(add_node);
  MS_EXCEPTION_IF_NULL(neg_node);

  auto sub_node = default_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(sub_node);

  // Sub takes the place of Add, so it inherits Add's scope.
  sub_node->set_scope(add_node->scope());

  // Sub(x, y) computes the same value as Add(x, Neg(y)), so Add's abstract is reused.
  auto add_abstract = add_node->abstract();
  if (add_abstract == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to create Sub abstract from Add node";
  }
  sub_node->set_abstract(add_abstract->Clone());

  return sub_node;
}

/**
 * Declares the replacement sub-graph: Sub(x, y), finalized by BuildSub.
 *
 * @param dst_pattern Pattern builder supplied by the framework; must not be null.
 */
void AddNegFusionPass::DefineDstPattern(DstPattern *dst_pattern) {
  MS_LOG(INFO) << "Defining destination pattern for AddNegFusionPass";
  MS_EXCEPTION_IF_NULL(dst_pattern);

  // Subtract y directly instead of adding its negation; x and y are the source-pattern vars.
  auto sub_prim = std::make_shared<Primitive>("Sub");
  dst_pattern->AddCNode("sub", {sub_prim, "x", "y"}, BuildSub);

  MS_LOG(INFO) << "Destination pattern defined: Sub(x, y)";
}

/**
 * Final validation hook for a candidate match of Add(x, Neg(y)).
 *
 * @param pattern_map Node bindings produced by matching the source pattern.
 * @param func_graph Graph containing the match; not consulted by this pass.
 * @param node Root node of the match; not consulted by this pass.
 * @return true when "add", "neg", "x" and "y" are all bound to non-null nodes;
 *         false (after logging an error) otherwise.
 */
bool AddNegFusionPass::CheckMatchedDAG(const PatternMap &pattern_map,
                                       [[maybe_unused]] const FuncGraphPtr &func_graph,
                                       [[maybe_unused]] const AnfNodePtr &node) const {
  // func_graph / node are marked [[maybe_unused]] so the -Wextra build does not warn about them.
  auto add_node = pattern_map.Get("add");
  if (!add_node) {
    MS_LOG(ERROR) << "Add node not found in pattern match";
    return false;
  }

  auto neg_node = pattern_map.Get("neg");
  if (!neg_node) {
    MS_LOG(ERROR) << "Neg node not found in pattern match";
    return false;
  }

  auto x_node = pattern_map.Get("x");
  if (!x_node) {
    MS_LOG(ERROR) << "x node not found in pattern match";
    return false;
  }

  auto y_node = pattern_map.Get("y");
  if (!y_node) {
    MS_LOG(ERROR) << "y node not found in pattern match";
    return false;
  }

  MS_LOG(INFO) << "AddNeg fusion pattern matched successfully";
  return true;
}
}  // namespace opt
}  // namespace mindspore
// ms_custom_pass_plugin.cc
// cpp file of Custom Pass Plugin
#include <string>
#include <memory>
#include <vector>
#include "mindspore/ccsrc/include/backend/common/custom_pass/custom_pass_plugin.h"
#include "add_neg_fusion_pass.h"

namespace mindspore {
namespace opt {

/**
 * @brief Plugin object that exposes the passes implemented in this library.
 *
 * GetAvailablePassNames() must only advertise names that CreatePass() can actually build;
 * otherwise callers are offered a pass that always comes back as nullptr.
 */
class MSCustomPassPlugin : public CustomPassPlugin {
 public:
  /// @brief Name under which this plugin identifies itself.
  std::string GetPluginName() const override { return "ms_custom_pass_plugin"; }

  /// @brief Names of all passes CreatePass() can construct.
  std::vector<std::string> GetAvailablePassNames() const override { return {"AddNegFusionPass"}; }

  /**
   * @brief Instantiates the pass called @p pass_name.
   * @return A new pass instance, or nullptr (with a warning) for an unknown name.
   */
  std::shared_ptr<Pass> CreatePass(const std::string &pass_name) const override {
    if (pass_name != "AddNegFusionPass") {
      MS_LOG(WARNING) << "Pass '" << pass_name << "' not found, available: AddNegFusionPass";
      return nullptr;
    }
    auto pass = std::make_shared<AddNegFusionPass>();
    MS_LOG(INFO) << "Created pass '" << pass_name << "' successfully";
    return pass;
  }
};
}  // namespace opt
}  // namespace mindspore

EXPORT_CUSTOM_PASS_PLUGIN(mindspore::opt::MSCustomPassPlugin)

Compiling Custom Pass Plugin

Compile the above example code into libcustom_pass.so. The CMake script is as follows:

cmake_minimum_required(VERSION 3.16)
project(pass VERSION 1.0.0 LANGUAGES CXX)

# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# MINDSPORE_ROOT must point at the MindSpore installation (configure with -DMINDSPORE_ROOT=...).
# Without this check an unset value silently turns the paths below into /include and /lib.
if(NOT MINDSPORE_ROOT OR NOT EXISTS "${MINDSPORE_ROOT}")
    message(FATAL_ERROR "MINDSPORE_ROOT is not set or does not exist; "
                        "configure with -DMINDSPORE_ROOT=/path/to/mindspore")
endif()

# Use specified MindSpore path
set(MINDSPORE_INCLUDE_DIR ${MINDSPORE_ROOT}/include)
set(MINDSPORE_LIB_DIRS ${MINDSPORE_ROOT}/lib)
message(STATUS "Using MindSpore from: ${MINDSPORE_ROOT}")

# Default to a Release build, but respect a build type chosen on the command line
if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "Release")
endif()

# Find MindSpore library files
file(GLOB MINDSPORE_LIBS "${MINDSPORE_ROOT}/lib/libmindspore_*.so")

list(LENGTH MINDSPORE_LIBS MINDSPORE_LIB_COUNT)
message(STATUS "Found library files count: ${MINDSPORE_LIB_COUNT}")

if(MINDSPORE_LIBS)
    # Prefer libmindspore_core.so for the symbol check; otherwise use the first library found
    list(FIND MINDSPORE_LIBS "${MINDSPORE_ROOT}/lib/libmindspore_core.so" CORE_LIB_INDEX)
    message(STATUS "libmindspore_core.so index: ${CORE_LIB_INDEX}")

    if(CORE_LIB_INDEX GREATER_EQUAL 0)
        list(GET MINDSPORE_LIBS ${CORE_LIB_INDEX} MINDSPORE_LIB)
    else()
        list(GET MINDSPORE_LIBS 0 MINDSPORE_LIB)
    endif()

    message(STATUS "Selected library file: ${MINDSPORE_LIB}")

    # The plugin is configured with the same hash-table implementation as MindSpore
    # (robin_hood vs. std::unordered_map); detect which one by scanning the library's strings.
    message(STATUS "Starting robin_hood symbol check...")

    # Method 1: use the external `strings` tool through bash
    execute_process(
        COMMAND bash -c "strings '${MINDSPORE_LIB}' | grep -i robin_hood | head -1"
        OUTPUT_VARIABLE ROBIN_HOOD_CHECK
        ERROR_VARIABLE ROBIN_HOOD_CHECK_ERROR
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )

    message(STATUS "robin_hood check result: '${ROBIN_HOOD_CHECK}'")
    message(STATUS "robin_hood check error: '${ROBIN_HOOD_CHECK_ERROR}'")

    # Method 2: if bash/strings produced nothing, scan the library with CMake itself.
    # file(STRINGS) is meant for binary files and keeps only the matching strings, instead of
    # loading the whole shared library into a variable as file(READ) would.
    if(NOT ROBIN_HOOD_CHECK)
        message(STATUS "bash method failed, trying to read file directly...")
        file(STRINGS "${MINDSPORE_LIB}" ROBIN_HOOD_STRINGS REGEX "robin_hood")
        if(ROBIN_HOOD_STRINGS)
            set(ROBIN_HOOD_CHECK "found_in_file")
            message(STATUS "Found robin_hood in file content")
        endif()
    endif()

    if(ROBIN_HOOD_CHECK)
        message(STATUS "MindSpore uses robin_hood::unordered_map")

        # Check if robin_hood.h exists
        if(EXISTS "${MINDSPORE_ROOT}/include/third_party/robin_hood_hashing/include/robin_hood.h")
            message(STATUS "Found robin_hood.h: "
                    "${MINDSPORE_ROOT}/include/third_party/robin_hood_hashing/include/robin_hood.h")
            add_compile_definitions(ENABLE_FAST_HASH_TABLE=1)
            add_compile_definitions(HASHMAP_TYPE="robin_hood")
            # Add robin_hood header file path
            include_directories("${MINDSPORE_ROOT}/include/third_party/robin_hood_hashing")
            message(STATUS "Using fast hash table (robin_hood) for ms_custom_ops to match MindSpore")
        else()
            message(FATAL_ERROR "robin_hood.h not found under mindspore install path")
        endif()
    else()
        message(STATUS "MindSpore uses std::unordered_map")
        add_compile_definitions(HASHMAP_TYPE="std")
        message(STATUS "Using standard hash table (std::unordered_map) for ms_custom_ops to match MindSpore")
    endif()
else()
    # Report the directory that was actually searched by the GLOB above
    message(FATAL_ERROR "MindSpore library not found in ${MINDSPORE_ROOT}/lib/")
endif()

# Header search paths: this project's own directory first, then the MindSpore install tree.
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

if(MINDSPORE_INCLUDE_DIR)
    # The MindSpore headers are reached through several roots (core, ccsrc, ops, kernel), so all
    # of them are added, in this order. third_party holds securec.h and its include directory.
    include_directories(
        ${MINDSPORE_INCLUDE_DIR}
        ${MINDSPORE_INCLUDE_DIR}/mindspore
        ${MINDSPORE_INCLUDE_DIR}/mindspore/core/include
        ${MINDSPORE_INCLUDE_DIR}/mindspore/ccsrc/
        ${MINDSPORE_INCLUDE_DIR}/mindspore/ccsrc/include/
        ${MINDSPORE_INCLUDE_DIR}/mindspore/ops
        ${MINDSPORE_INCLUDE_DIR}/mindspore/ops/include
        ${MINDSPORE_INCLUDE_DIR}/mindspore/ops/kernel/include
        ${MINDSPORE_INCLUDE_DIR}/third_party
        ${MINDSPORE_INCLUDE_DIR}/third_party/securec/include
    )
endif()

# Collect all plugin sources. CONFIGURE_DEPENDS makes the build re-check the glob, so newly
# added .cc files are picked up without manually re-running cmake.
file(GLOB_RECURSE PASS_SOURCES CONFIGURE_DEPENDS "*.cc")

# Create dynamic library (based on installed MindSpore)
add_library(custom_pass SHARED ${PASS_SOURCES})

# Link MindSpore libraries (based on actual requirements)
target_link_libraries(custom_pass
    ${MINDSPORE_LIB_DIRS}/libmindspore_backend_common.so
    ${MINDSPORE_LIB_DIRS}/libmindspore_core.so
    ${MINDSPORE_LIB_DIRS}/libmindspore_common.so
)

# ENABLE_GLIBCXX decides whether the old libstdc++ ABI is forced below. Later steps override
# earlier ones: option default OFF -> forced ON on non-Linux systems -> ENABLE_GLIBCXX env var.
option(ENABLE_GLIBCXX "enable_glibcxx" OFF)

if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux")
    set(ENABLE_GLIBCXX ON)
endif()

if(DEFINED ENV{ENABLE_GLIBCXX})
    set(ENABLE_GLIBCXX $ENV{ENABLE_GLIBCXX})
endif()

# On Linux, unless ENABLE_GLIBCXX is set, compile with the pre-C++11 libstdc++ ABI
if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND NOT ENABLE_GLIBCXX)
    add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
endif()

# Per-target compiler flags
target_compile_options(custom_pass PRIVATE -fPIC -std=c++17 -Wall -Wextra)

# Use the same libstdc++ ABI setting as MindSpore on Linux
if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND NOT ENABLE_GLIBCXX)
    target_compile_definitions(custom_pass PRIVATE _GLIBCXX_USE_CXX11_ABI=0)
endif()

# Plugin preprocessor definitions (CMake strips a leading -D, so bare names are equivalent)
target_compile_definitions(custom_pass PRIVATE PASS_PLUGIN_EXPORTS MINDSPORE_PASS)

# Output libcustom_pass.so, versioned from the project version
set_target_properties(custom_pass PROPERTIES
    VERSION ${PROJECT_VERSION}
    SOVERSION ${PROJECT_VERSION_MAJOR}
    PREFIX "lib"
    OUTPUT_NAME "custom_pass"
)

# Installation rules
install(TARGETS custom_pass
    LIBRARY DESTINATION lib
    RUNTIME DESTINATION bin
)

The compilation command is as follows:

# Configure in the source directory; MINDSPORE_ROOT is the MindSpore installation directory
cmake . -DMINDSPORE_ROOT=/path/to/mindspore
# Build the custom_pass shared library (libcustom_pass.so)
make

/path/to/mindspore is the MindSpore installation directory, i.e. the directory that contains the include and lib subdirectories referenced by the CMake script.

Using Custom Pass

Use mindspore.graph.register_custom_pass to register and enable the custom pass:

import numpy as np
import mindspore
from mindspore import jit, ops, nn, Tensor

# Location of the plugin library built above; adjust to where libcustom_pass.so was produced.
custom_path = "/data1/libcustom_pass.so"
# Register AddNegFusionPass from the plugin for the "cpu" device.
success = mindspore.graph.register_custom_pass("AddNegFusionPass", custom_path, "cpu")
assert success, "Plugin registration failed"


class AddNegNetwork(nn.Cell):
    """Computes x1 + (-x2), i.e. the Add(x, Neg(y)) pattern that AddNegFusionPass rewrites."""

    def __init__(self):
        super().__init__()
        self.neg = ops.Neg()

    @jit(backend="ms_backend")
    def construct(self, x1, x2):
        # Add of a Neg: x1 + (-x2), which the pass turns into x1 - x2
        return x1 + self.neg(x2)


mindspore.set_device("CPU")
net = AddNegNetwork()

# Both inputs have shape (2, 3, 4): x1 holds 1..24, x2 holds 1..6 with each value filling a row of 4.
x1 = Tensor(np.arange(1, 25).reshape(2, 3, 4).astype(np.float32))
x2 = Tensor(np.repeat(np.arange(1, 7), 4).reshape(2, 3, 4).astype(np.float32))
output = net(x1, x2)

# The fused graph must still produce x1 + (-x2) == x1 - x2
expected = x1.asnumpy() - x2.asnumpy()
np.testing.assert_array_almost_equal(output.asnumpy(), expected)