mindspore_serving.server

MindSpore Serving is a lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.

The MindSpore Serving server API is used to start servables and the gRPC and RESTful servers. A servable corresponds to the service provided by a model. The client sends inference tasks and receives inference results through the gRPC and RESTful servers.

class mindspore_serving.server.SSLConfig(certificate, private_key, custom_ca=None, verify_client=False)[source]

The server’s ssl_config encapsulates necessary parameters for SSL-enabled connections.

Parameters
  • certificate (str) – File holding the PEM-encoded certificate chain as a byte string to use or None if no certificate chain should be used.

  • private_key (str) – File holding the PEM-encoded private key as a byte string, or None if no private key should be used.

  • custom_ca (str, optional) – File holding the PEM-encoded root certificates as a byte string. When verify_client is True, custom_ca must be provided. When verify_client is False, this parameter will be ignored. Default: None.

  • verify_client (bool, optional) – If verify_client is true, use mutual authentication. If false, use one-way authentication. Default: False.

Raises

RuntimeError – The type or value of the parameters are invalid.
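The relationship between verify_client and custom_ca described above can be summarized in a small standalone check. This is only a sketch of the documented rule, not the library's own validation code. The helper name authentication_mode and the file name "ca.crt" are hypothetical; RuntimeError is used because it is the exception this class is documented to raise.

```python
def authentication_mode(custom_ca=None, verify_client=False):
    """Hypothetical helper: apply the documented verify_client/custom_ca rule."""
    if verify_client:
        # Mutual authentication needs root certificates to verify the client.
        if not custom_ca:
            raise RuntimeError("custom_ca must be provided when verify_client is True")
        return "mutual"
    # One-way authentication: custom_ca is ignored even if it is given.
    return "one-way"


print(authentication_mode(custom_ca="ca.crt", verify_client=True))
print(authentication_mode(custom_ca="ca.crt", verify_client=False))
```

The second call shows that passing custom_ca without verify_client has no effect on the mode.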

class mindspore_serving.server.ServableStartConfig(servable_directory, servable_name, device_ids=None, version_number=0, device_type=None, num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM')[source]

Servable startup configuration.

For more detail, please refer to MindSpore-based Inference Service Deployment and Servable Provided Through Model Configuration.

Parameters
  • servable_directory (str) – The directory in which the servable is located. It is expected to contain a directory named servable_name.

  • servable_name (str) – The servable name.

  • device_ids (Union[int, list[int], tuple[int]], optional) – The list of devices that the model is loaded into and runs on. Used when the device type is Nvidia GPU or Ascend 310P/910. Default: None.

  • version_number (int, optional) – Servable version number to be loaded. The version number should be a positive integer, starting from 1, and 0 means to load the latest version. Default: 0.

  • device_type (str, optional) –

    Target device type for model deployment. Currently supports “Ascend”, “GPU”, “CPU” and None. Default: None.

    • “Ascend”: the platform is expected to be Ascend 310P/910, etc.

    • “GPU”: the platform is expected to be Nvidia GPU.

    • “CPU”: the platform is expected to be CPU.

    • None: the platform is determined by the MindSpore environment.

  • num_parallel_workers (int, optional) – The number of processes that handle Python tasks. It should be at least the number of device cards specified by device_ids, and it is raised to that number when it is smaller. The value should be in the range [0, 64]. Default: 0.

  • dec_key (bytes, optional) – Byte type key used for decryption. The valid length is 16, 24, or 32. Default: None.

  • dec_mode (str, optional) – The decryption mode, which takes effect only when dec_key is set. Options: ‘AES-GCM’ or ‘AES-CBC’. Default: ‘AES-GCM’.

Raises

RuntimeError – The type or value of the parameters are invalid.
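The numeric rules in the parameter list above can be illustrated with plain Python. This is a minimal sketch under stated assumptions, not the library's implementation: all three helper names are hypothetical, "latest version" is assumed to mean the highest version number, and the behavior when device_ids is None is not described in the text, so the sketch simply returns the requested worker count in that case.

```python
def effective_worker_count(num_parallel_workers, device_ids=None):
    """Hypothetical helper: documented adjustment of num_parallel_workers."""
    if not 0 <= num_parallel_workers <= 64:
        raise RuntimeError("num_parallel_workers must be in the range [0, 64]")
    if device_ids is None:
        device_count = 0
    elif isinstance(device_ids, int):
        device_count = 1
    else:
        device_count = len(device_ids)
    # The value is raised to the number of device cards when it is smaller.
    return max(num_parallel_workers, device_count)


def resolve_version(version_number, available_versions):
    """Hypothetical helper: 0 selects the latest (assumed: highest) version."""
    if version_number == 0:
        return max(available_versions)
    if version_number not in available_versions:
        raise RuntimeError(f"version {version_number} is not available")
    return version_number


def is_valid_dec_key(dec_key):
    """Hypothetical helper: the documented key lengths are 16, 24, or 32 bytes."""
    return isinstance(dec_key, bytes) and len(dec_key) in (16, 24, 32)


print(effective_worker_count(0, device_ids=(0, 1)))
print(resolve_version(0, [1, 2, 5]))
print(is_valid_dec_key(b"k" * 16))
```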

mindspore_serving.server.start_grpc_server(address, max_msg_mb_size=100, ssl_config=None)[source]

Start gRPC server for the communication between serving client and server.

Parameters
  • address (str) –

    gRPC server address, the address can be {ip}:{port} or unix:{unix_domain_file_path}.

    • {ip}:{port} - Internet domain socket address.

    • unix:{unix_domain_file_path} - Unix domain socket address, which is used to communicate with multiple processes on the same machine. {unix_domain_file_path} can be relative or absolute file path, but the directory where the file is located must already exist.

  • max_msg_mb_size (int, optional) – The maximum acceptable gRPC message size in megabytes (MB), in the range [1, 512]. Default: 100.

  • ssl_config (mindspore_serving.server.SSLConfig, optional) – The server’s SSL configuration. If None, SSL is disabled. Default: None.

Raises

RuntimeError – Failed to start the gRPC server: parameter verification failed, the gRPC address is wrong or the port is duplicate.
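The two address forms accepted by the address parameter can be told apart with a few lines of standard-library Python. The helper below is a hypothetical sketch for illustration and is not part of mindspore_serving. It only mirrors the documented rules: a unix: prefix selects a Unix domain socket whose directory must already exist, and anything else is treated as {ip}:{port}.

```python
import os


def parse_server_address(address):
    """Hypothetical helper: classify an address as Unix or Internet domain."""
    if address.startswith("unix:"):
        path = address[len("unix:"):]
        # The file may be relative or absolute, but its directory must exist.
        directory = os.path.dirname(os.path.abspath(path))
        if not path or not os.path.isdir(directory):
            raise RuntimeError(f"directory does not exist for: {address}")
        return ("unix", path)
    ip, separator, port = address.rpartition(":")
    if not separator or not ip or not port.isdigit():
        raise RuntimeError(f"expected {{ip}}:{{port}}, got: {address}")
    port_number = int(port)
    if not 0 < port_number < 65536:
        raise RuntimeError(f"port out of range: {port_number}")
    return ("inet", ip, port_number)


print(parse_server_address("0.0.0.0:5500"))
print(parse_server_address("unix:serving.sock"))
```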

Examples

>>> from mindspore_serving import server
>>>
>>> server.start_grpc_server("0.0.0.0:5500")

mindspore_serving.server.start_restful_server(address, max_msg_mb_size=100, ssl_config=None)[source]

Start RESTful server for the communication between serving client and server.

Parameters
  • address (str) – RESTful server address, which should be an Internet domain socket address in the form {ip}:{port}.

  • max_msg_mb_size (int, optional) – The maximum acceptable RESTful message size in megabytes (MB), in the range [1, 512]. Default: 100.

  • ssl_config (mindspore_serving.server.SSLConfig, optional) – The server’s SSL configuration. If None, SSL is disabled. Default: None.

Raises

RuntimeError – Failed to start the RESTful server: parameter verification failed, the RESTful address is wrong or the port is duplicate.

Examples

>>> from mindspore_serving import server
>>>
>>> server.start_restful_server("0.0.0.0:5900")

mindspore_serving.server.start_servables(servable_configs, enable_lite=False)[source]

Used to start one or more servables on the serving server. One model can be combined with preprocessing and postprocessing to provide a servable, and multiple models can also be combined to provide a servable.

This interface can be used to start multiple different servables. One servable can be deployed on multiple devices, and each device runs a servable copy.

On the Ascend 910 hardware platform, each copy of each servable owns one device. Different servables or different versions of the same servable need to be deployed on different devices. On the Ascend 310P and GPU hardware platforms, one device can be shared by multiple servables, and different servables or different versions of the same servable can be deployed on the same chip to realize device reuse.

For details about how to configure models to provide servables, please refer to MindSpore-based Inference Service Deployment and Servable Provided Through Model Configuration.
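The device-sharing rule above can be checked before startup with a small planning helper. This is a sketch only: find_shared_devices is a hypothetical name, and the input is a plain dictionary of servable names to device ids rather than real ServableStartConfig objects.

```python
def find_shared_devices(servable_devices):
    """Hypothetical helper: return device ids requested by more than one servable.

    servable_devices maps a servable name to an iterable of device ids.
    """
    users = {}
    for name, device_ids in servable_devices.items():
        for device_id in device_ids:
            users.setdefault(device_id, []).append(name)
    return {device: names for device, names in users.items() if len(names) > 1}


separate = {"resnet": (0, 1), "add": (2, 3)}
overlapping = {"resnet": (0, 1), "add": (1, 2)}
print(find_shared_devices(separate))
print(find_shared_devices(overlapping))
```

According to the text, a non-empty result describes a layout that is not allowed on Ascend 910, while the same layout is acceptable on Ascend 310P and GPU.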

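Parameters
  • servable_configs (Union[ServableStartConfig, list[ServableStartConfig], tuple[ServableStartConfig]]) – The startup configs of one or more servables.

  • enable_lite (bool, optional) – Whether to use the MindSpore Lite inference backend. Default: False.

Raises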

RuntimeError – Failed to start one or more servables. For the log of a servable, refer to the serving_logs subdirectory of the directory where the startup script is located.

Examples

>>> import os
>>> from mindspore_serving import server
>>>
>>> servable_dir = os.path.abspath(".")
>>> resnet_config = server.ServableStartConfig(servable_dir, "resnet", device_ids=(0,1))
>>> add_config = server.ServableStartConfig(servable_dir, "add", device_ids=(2,3))
>>> server.start_servables(servable_configs=(resnet_config, add_config))  # press Ctrl+C to stop
>>> server.start_grpc_server("0.0.0.0:5500")

mindspore_serving.server.stop()[source]

Stop the running serving server.

Examples

>>> from mindspore_serving import server
>>>
>>> server.start_grpc_server("0.0.0.0:5500")
>>> server.start_restful_server("0.0.0.0:1500")
>>> ...
>>> server.stop()

mindspore_serving.server.register

Servable register interface, used in the servable_config.py file of a servable. For how to configure the servable_config.py file, refer to Servable Provided Through Model Configuration.

class mindspore_serving.server.register.AscendDeviceInfo(**kwargs)[source]

Helper class to set Ascend device info.

Parameters
  • insert_op_cfg_path (str, optional) – Path of aipp config file.

  • input_format (str, optional) – Manually specify the model input format, the value can be "ND", "NCHW", "NHWC", "CHWN", "NC1HWC0", or "NHWC1C0".

  • input_shape (str, optional) – Manually specify the model input shape, such as "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1".

  • output_type (str, optional) – Manually specify the model output type, the value can be "FP16", "UINT8" or "FP32". Default: "FP32".

  • precision_mode (str, optional) – Model precision mode, the value can be "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision". Default: "force_fp16".

  • op_select_impl_mode (str, optional) – The operator selection mode, the value can be "high_performance" or "high_precision". Default: "high_performance".

  • fusion_switch_config_path (str, optional) – Configuration file path of the fusion rules, including graph fusion and UB fusion. The system has built-in graph fusion and UB fusion rules, which are enabled by default. You can disable the rules specified in the file by setting this parameter.

  • buffer_optimize_mode (str, optional) – The value can be "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize". Default: "l2_optimize".

Raises

RuntimeError – Ascend device info is invalid.
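The enumerated option values listed above lend themselves to a table-driven check. The following is a hedged sketch, not the library's validation logic: check_ascend_options, ALLOWED_VALUES, and FREE_FORM_OPTIONS are hypothetical names, and the comparison is an exact, case-sensitive match, which is an assumption because the text does not say whether case is ignored.

```python
# Values copied from the parameter descriptions above.
ALLOWED_VALUES = {
    "input_format": {"ND", "NCHW", "NHWC", "CHWN", "NC1HWC0", "NHWC1C0"},
    "output_type": {"FP16", "UINT8", "FP32"},
    "precision_mode": {"force_fp16", "allow_fp32_to_fp16",
                       "must_keep_origin_dtype", "allow_mix_precision"},
    "op_select_impl_mode": {"high_performance", "high_precision"},
    "buffer_optimize_mode": {"l1_optimize", "l2_optimize",
                             "off_optimize", "l1_and_l2_optimize"},
}
# Options whose values are paths or shape strings and are not enumerated.
FREE_FORM_OPTIONS = {"insert_op_cfg_path", "input_shape", "fusion_switch_config_path"}


def check_ascend_options(**kwargs):
    """Hypothetical helper: validate keyword options against the lists above."""
    for name, value in kwargs.items():
        if name not in ALLOWED_VALUES and name not in FREE_FORM_OPTIONS:
            raise RuntimeError(f"unknown option: {name}")
        if not isinstance(value, str):
            raise RuntimeError(f"{name} must be a str")
        if name in ALLOWED_VALUES and value not in ALLOWED_VALUES[name]:
            raise RuntimeError(f"invalid value for {name}: {value}")
    return kwargs


print(check_ascend_options(input_format="NCHW", precision_mode="force_fp16"))
```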

Examples

>>> from mindspore_serving.server import register
>>> context = register.Context()
>>> context.append_device_info(register.AscendDeviceInfo(input_format="NCHW"))
>>> model = register.declare_model(model_file="deeptext.ms", model_format="MindIR_Lite", context=context)

class mindspore_serving.server.register.CPUDeviceInfo(**kwargs)[source]

Helper class to set CPU device info.

Parameters

precision_mode (str, optional) – Model precision option. The value can be "origin" or "fp16". "origin" indicates that inference is performed with the precision defined in the model, and "fp16" indicates that inference is performed with FP16 precision. Default: "origin".

Raises

RuntimeError – The CPU option is invalid, or the value is not a str.

Examples

>>> from mindspore_serving.server import register
>>> context = register.Context()
>>> context.append_device_info(register.CPUDeviceInfo(precision_mode="fp16"))
>>> model = register.declare_model(model_file="deeptext.ms", model_format="MindIR_Lite", context=context)

class mindspore_serving.server.register.Context(**kwargs)[source]

Context is used to customize device configurations. If Context is not specified, MindSpore Serving uses the default device configurations. When the inference backend is MindSpore Lite and the device type is Ascend or GPU, an extra CPUDeviceInfo is also used.

Parameters
  • thread_num (int, optional) – The number of threads at runtime. Only valid when the inference backend is MindSpore Lite.

  • thread_affinity_core_list (tuple[int], list[int], optional) – The list of CPU cores that the threads are bound to. Only valid when the inference backend is MindSpore Lite.

  • enable_parallel (bool, optional) – Whether to perform model inference or training in parallel. Only valid when the inference backend is MindSpore Lite.

Raises

RuntimeError – The type or value of the input parameters is invalid.

Examples

>>> from mindspore_serving.server import register
>>> context = register.Context(thread_num=1, thread_affinity_core_list=[1,2], enable_parallel=True)
>>> context.append_device_info(register.GPUDeviceInfo(precision_mode="fp16"))
>>> model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", context=context)

append_device_info(device_info)[source]

Append one user-defined device info to the context.

Parameters

device_info (Union[CPUDeviceInfo, GPUDeviceInfo, AscendDeviceInfo]) – User-defined device info for one device. Default values are used for devices without user-defined info. You can customize device info for each device, and the system selects the required device info based on the actual backend device and MindSpore inference package.

Raises

RuntimeError – The type or value of the input parameters is invalid.

class mindspore_serving.server.register.GPUDeviceInfo(**kwargs)[source]

Helper class to set GPU device info.

Parameters

precision_mode (str, optional) – Model precision option. The value can be "origin" or "fp16". "origin" indicates that inference is performed with the precision defined in the model, and "fp16" indicates that inference is performed with FP16 precision. Default: "origin".

Raises

RuntimeError – The GPU option is invalid, or the value is not a str.

Examples

>>> from mindspore_serving.server import register
>>> context = register.Context()
>>> context.append_device_info(register.GPUDeviceInfo(precision_mode="fp16"))
>>> model = register.declare_model(model_file="deeptext.mindir", model_format="MindIR", context=context)

class mindspore_serving.server.register.Model(model_key)[source]

Represents a model. Users should not construct a Model object directly. It should be obtained from declare_model or declare_servable.

Parameters

model_key (str) – Model key identifies the model.

call(*args, subgraph=0)[source]

Invoke the model inference interface based on instances.

Parameters
  • args – tuple/list of instances, or inputs of one instance.

  • subgraph (int, optional) – Subgraph index, used when there are multiple subgraphs in one model. Default: 0.

Returns

Tuple of instances when input parameter ‘args’ is tuple/list, or outputs of one instance.

Raises

RuntimeError – Inputs are invalid.
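The two calling forms described above (a tuple/list of instances, or the inputs of one instance) can be illustrated with a stand-in class. FakeModel below is hypothetical and is NOT the real register.Model. It doubles its single input so that the example runs without a model file, and it assumes that a single list or tuple argument always means "a list of instances".

```python
class FakeModel:
    """Stand-in with one input and one output; not the real register.Model."""

    def _infer(self, inputs):
        # Each instance carries exactly one input in this toy model.
        (x,) = inputs
        return x * 2

    def call(self, *args, subgraph=0):
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            # A tuple/list of instances: return one output list per instance.
            return tuple([self._infer(instance)] for instance in args[0])
        # Otherwise args are the inputs of one instance: return its output.
        return self._infer(args)


model = FakeModel()
print(model.call(3))
print(model.call([[1], [2]]))
```

In the instance form, each element of the result is indexed like instance[0] to obtain the first output of that instance.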

Examples

>>> import numpy as np
>>> from mindspore_serving.server import register
>>> import mindspore.dataset.vision.c_transforms as VC
>>> model = register.declare_model(model_file="resnet_bs32.mindir", model_format="MindIR") # batch_size=32
>>>
>>> def preprocess(image):
...     decode = VC.Decode()
...     resize = VC.Resize([224, 224])
...     normalize = VC.Normalize(mean=[125.307, 122.961, 113.8575], std=[51.5865, 50.847, 51.255])
...     hwc2chw = VC.HWC2CHW()
...     image = decode(image)
...     image = resize(image) # [224,224,3]
...     image = normalize(image) # [224,224,3]
...     image = hwc2chw(image) # [3,224,224]
...     return image
>>>
>>> def postprocess(score):
...     return np.argmax(score)
>>>
>>> def call_resnet_model(image):
...     image = preprocess(image)
...     score = model.call(image)  # for only one instance
...     return postprocess(score)
>>>
>>> def call_resnet_model_batch(instances):
...     input_instances = []
...     for instance in instances:
...         image = instance[0] # only one input
...         image = preprocess(image) # [3,224,224]
...         input_instances.append([image])
...     output_instances = model.call(input_instances)  # for multiple instances
...     for instance in output_instances:
...         score = instance[0]  # only one output for each instance
...         index = postprocess(score)
...         yield index
>>>
>>> @register.register_method(output_names=["index"])
... def predict_v1(image):  # without pipeline, call model with only one instance at a time
...     index = register.add_stage(call_resnet_model, image, outputs_count=1)
...     return index
>>>
>>> @register.register_method(output_names=["index"])
... def predict_v2(image):  # without pipeline, call model with maximum 32 instances at a time
...     index = register.add_stage(call_resnet_model_batch, image, outputs_count=1, batch_size=32)
...     return index
>>>
>>> @register.register_method(output_names=["index"])
... def predict_v3(image):  # pipeline
...     image = register.add_stage(preprocess, image, outputs_count=1)
...     score = register.add_stage(model, image, outputs_count=1)
...     index = register.add_stage(postprocess, score, outputs_count=1)
...     return index

mindspore_serving.server.register.add_stage(stage, *args, outputs_count, batch_size=None, tag=None)[source]

In the servable_config.py file of one servable, we use register_method to wrap a Python function to define a method of the servable, and add_stage is used to define a stage of this method, which can be a Python function or a model.

Note

The length of ‘args’ should be equal to the number of inputs of the function or model.

Parameters
  • stage (Union[function, Model]) – User-defined Python function or Model object returned by declare_model.

  • outputs_count (int) – Outputs count of the user-defined Python function or model.

  • batch_size (int, optional) –

    This parameter is valid only when stage is a function that can process multiple instances at a time. Default: None.

    • None: The input of the function is the inputs of one instance.

    • 0: The input of the function is a tuple of instances, and the maximum number of instances is determined by the server based on the batch size of the models.

    • int value >= 1: The input of the function is a tuple of instances, and the maximum number of instances is the value specified by ‘batch_size’.

  • args – Stage input placeholders, which come from the inputs of the function wrapped by register_method or the outputs of add_stage. The length of ‘args’ should be equal to the number of inputs of the function or model.

  • tag (str, optional) – Customized flag of the stage, such as "Preprocess". Default: None.

Raises

RuntimeError – The type or value of the parameters are invalid, or other error happened.
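The effect of batch_size on what a function stage receives can be simulated in plain Python. This is a conceptual sketch only, not how the server schedules work internally. run_stage and its model_batch_size argument are hypothetical, with model_batch_size standing in for "the batch size of models" that the server uses when batch_size is 0.

```python
def run_stage(function, instances, batch_size=None, model_batch_size=1):
    """Hypothetical simulation of how a function stage is fed with instances."""
    results = []
    if batch_size is None:
        # One call per instance; the function receives that instance's inputs.
        for instance in instances:
            results.append(function(*instance))
        return results
    # batch_size == 0 lets the "server" pick the limit from the model batch size.
    limit = model_batch_size if batch_size == 0 else batch_size
    for start in range(0, len(instances), limit):
        batch = tuple(instances[start:start + limit])
        # The function receives a tuple of instances and yields one result each.
        results.extend(function(batch))
    return results


def double(x):
    return x * 2


def double_batch(instances):
    for (x,) in instances:
        yield x * 2


instances = [(1,), (2,), (3,)]
print(run_stage(double, instances))
print(run_stage(double_batch, instances, batch_size=2))
```

Both calls produce the same results. The difference is that double is invoked three times, while double_batch is invoked twice, with two instances and then one.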

Examples

>>> import numpy as np
>>> from mindspore_serving.server import register
>>> add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR")
>>>
>>> def preprocess(x1, x2):
...     return x1.astype(np.float32), x2.astype(np.float32)
>>>
>>> @register.register_method(output_names=["y"]) # register add_common method in add
... def add_common(x1, x2):
...     x1, x2 = register.add_stage(preprocess, x1, x2, outputs_count=2) # call preprocess in stage 1
...     y = register.add_stage(add_model, x1, x2, outputs_count=1) # call add model in stage 2
...     return y

mindspore_serving.server.register.declare_model(model_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None, context=None, config_file=None)[source]

Declare one model when importing servable_config.py of one servable.

Note

This interface should take effect when importing servable_config.py by the serving server. Therefore, it’s recommended that this interface be used globally in servable_config.py.

Warning

The parameter ‘options’ is deprecated from version 1.6.0 and will be removed in a future version, use parameter ‘context’ instead.

Parameters
  • model_file (Union[str, list[str]]) – Model file name or names.

  • model_format (str) – Model format, "MindIR" or "MindIR_Lite", case ignored.

  • with_batch_dim (bool, optional) – Whether the first shape dim of the inputs and outputs of model is batch dim. Default: True.

  • options (Union[AclOptions, GpuOptions], optional) – Options of model, supports AclOptions or GpuOptions. Default: None.

  • context (Context, optional) – Context is used to store environment variables during execution. If the value is None, Serving uses the default device context based on the deployed device. Default: None.

  • without_batch_dim_inputs (Union[int, tuple[int], list[int]], optional) – Indices of the inputs that have no batch dimension when with_batch_dim is True. For example, if the shape of input 0 does not include the batch dimension, without_batch_dim_inputs can be set to (0,). Default: None.

  • config_file (str, optional) – Config file for model to set mix precision inference. The file path can be an absolute path or a relative path to the directory in which servable_config.py resides. Default: None.

Returns

Model, identification of this model, can be used for Model.call or as the inputs of add_stage.

Raises

RuntimeError – The type or value of the parameters are invalid.
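The interplay of with_batch_dim and without_batch_dim_inputs can be pictured in terms of shapes. The sketch below is an illustration under assumptions, not library code. batched_input_shapes is a hypothetical helper, and it assumes that the model input shape of a batched input is the per-instance shape with the batch size prepended.

```python
def batched_input_shapes(instance_shapes, batch_size, with_batch_dim=True,
                         without_batch_dim_inputs=None):
    """Hypothetical helper: show which inputs carry a leading batch dimension."""
    # Accept the same Union[int, tuple[int], list[int]] forms as the parameter.
    if without_batch_dim_inputs is None:
        skipped = ()
    elif isinstance(without_batch_dim_inputs, int):
        skipped = (without_batch_dim_inputs,)
    else:
        skipped = tuple(without_batch_dim_inputs)
    shapes = []
    for index, shape in enumerate(instance_shapes):
        if with_batch_dim and index not in skipped:
            shapes.append((batch_size,) + tuple(shape))
        else:
            shapes.append(tuple(shape))
    return shapes


# Input 0 is an image per instance; input 1 is shared and has no batch dim.
print(batched_input_shapes([(3, 224, 224), (2,)], batch_size=32,
                           without_batch_dim_inputs=1))
```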

mindspore_serving.server.register.register_method(output_names)[source]

Define a method of the servable when importing servable_config.py of one servable. One servable can include one or more methods, and each method provides a different service based on models. A client needs to specify the servable name and method name when accessing a service. MindSpore Serving supports a service consisting of multiple Python functions and multiple models.

Note

This interface should take effect when importing servable_config.py by the serving server. Therefore, it’s recommended that this interface be used globally in servable_config.py.

This interface will define the signatures and pipeline of the method.

The signatures include the method name and the input and output names of the method. When accessing a service, the client needs to specify the servable name, the method name, and provide one or more inference instances. Each instance specifies the input data by the input names and obtains the output data by the output names.

The pipeline consists of one or more stages, and each stage can be a Python function or a model. That is, a pipeline can include one or more Python functions and one or more models. In addition, the interface also defines the data flow between these stages.

Parameters

output_names (Union[str, tuple[str], list[str]]) – The output names of the method. The input names are the argument names of the registered function.

Raises

RuntimeError – The type or value of the parameters are invalid, or other error happened.
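The idea of a pipeline as an ordered list of stages with named data flowing between them can be sketched without the library. In this hypothetical simulation, run_pipeline is an illustrative helper, plain Python functions stand in for the models, and the intermediate name "t" is made up for the example.

```python
def run_pipeline(stages, inputs):
    """Hypothetical sketch: run stages in order over a dict of named values."""
    values = dict(inputs)
    for function, input_names, output_names in stages:
        outputs = function(*(values[name] for name in input_names))
        if len(output_names) == 1:
            # Normalize a single output to a 1-tuple for uniform handling.
            outputs = (outputs,)
        values.update(zip(output_names, outputs))
    return values


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


# Two stages computing x1 + x2 - x3, with "t" as the intermediate value.
stages = [
    (add, ("x1", "x2"), ("t",)),
    (sub, ("t", "x3"), ("y",)),
]
result = run_pipeline(stages, {"x1": 5, "x2": 3, "x3": 2})
print(result["y"])
```

The method inputs correspond to the initial dictionary, and the names listed in output_names correspond to the values read from the result.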

Examples

>>> from mindspore_serving.server import register
>>> add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR")
>>> sub_model = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR")
>>>
>>> @register.register_method(output_names=["y"]) # register predict method in servable
... def predict(x1, x2, x3): # x1+x2-x3
...     y = register.add_stage(add_model, x1, x2, outputs_count=1)
...     y = register.add_stage(sub_model, y, x3, outputs_count=1)
...     return y

mindspore_serving.server.distributed

The interface to start up the serving server with a distributed servable. For how to configure and start up a distributed model, refer to MindSpore Serving-based Distributed Inference Service Deployment.

mindspore_serving.server.distributed.declare_servable(rank_size, stage_size, with_batch_dim=True, without_batch_dim_inputs=None, enable_pipeline_infer=False)[source]

Declare a distributed servable in servable_config.py. For details, please refer to MindSpore Serving-based Distributed Inference Service Deployment.

Parameters
  • rank_size (int) – The rank size of the distributed model.

  • stage_size (int) – The stage size of the distributed model.

  • with_batch_dim (bool, optional) – Whether the first shape dim of the inputs and outputs of the model is the batch dim. Default: True.

  • without_batch_dim_inputs (Union[int, tuple[int], list[int]], optional) – Indices of the inputs that have no batch dimension when with_batch_dim is True. Default: None.

  • enable_pipeline_infer (bool, optional) – Whether to enable pipeline parallel inference. Pipeline parallelism can effectively improve inference performance. For details, see Pipeline Parallelism. Default: False.

Returns

Model, identification of this model, can be used for Model.call or as the inputs of add_stage.

Raises

RuntimeError – The type or value of the parameters are invalid.

Examples

>>> from mindspore_serving.server import distributed
>>> model = distributed.declare_servable(rank_size=8, stage_size=1)

mindspore_serving.server.distributed.start_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, distributed_address='0.0.0.0:6200', wait_agents_time_in_seconds=0)[source]

Start up the servable named ‘servable_name’ defined in ‘servable_directory’.

Parameters
  • servable_directory (str) – The directory in which the servable is located. It is expected to contain a directory named servable_name. For more details, see How to config Servable.

  • servable_name (str) – The servable name.

  • version_number (int, optional) – Servable version number to be loaded. The version number should be a positive integer, starting from 1. Default: 1.

  • rank_table_json_file (str) – The rank table json file name.

  • distributed_address (str, optional) – The distributed worker address that the worker agents connect to. Default: “0.0.0.0:6200”.

  • wait_agents_time_in_seconds (int, optional) – The maximum time in seconds that the worker waits for all agents to be ready. 0 means unlimited time. Default: 0.

Raises

RuntimeError – Failed to start the distributed servable.
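The meaning of wait_agents_time_in_seconds, where 0 stands for an unlimited wait, corresponds to a common polling pattern. The sketch below shows that pattern only and makes no claim about how the worker actually waits. wait_for_agents, all_ready, and poll_interval are hypothetical names, and raising RuntimeError on timeout is an assumption based on the exception documented above.

```python
import time


def wait_for_agents(all_ready, wait_agents_time_in_seconds=0, poll_interval=0.01):
    """Hypothetical sketch: poll all_ready() until it returns True.

    Returns the number of polls. A limit of 0 means wait without a deadline.
    """
    deadline = None
    if wait_agents_time_in_seconds > 0:
        deadline = time.monotonic() + wait_agents_time_in_seconds
    polls = 0
    while True:
        polls += 1
        if all_ready():
            return polls
        if deadline is not None and time.monotonic() >= deadline:
            raise RuntimeError("agents were not ready in time")
        time.sleep(poll_interval)


# Simulated agents that become ready on the third check.
states = iter([False, False, True])
print(wait_for_agents(lambda: next(states)))
```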

Examples

>>> import os
>>> from mindspore_serving.server import distributed
>>>
>>> servable_dir = os.path.abspath(".")
>>> distributed.start_servable(servable_dir, "matmul", rank_table_json_file="hccl_8p.json", \
...                            distributed_address="127.0.0.1:6200")

mindspore_serving.server.distributed.startup_agents(distributed_address, model_files, group_config_files=None, agent_start_port=7000, agent_ip=None, rank_start=None, dec_key=None, dec_mode='AES-GCM')[source]

Start all required worker agents on the current machine. These worker agent processes are responsible for inference tasks on the local machine. For details, please refer to MindSpore Serving-based Distributed Inference Service Deployment.

Parameters
  • distributed_address (str) – The distributed worker address that the agents connect to.

  • model_files (Union[list[str], tuple[str]]) – All model files needed on the current machine, as absolute paths or paths relative to this startup Python script.

  • group_config_files (Union[list[str], tuple[str]], optional) – All group config files needed on the current machine, as absolute paths or paths relative to this startup Python script. None means there are no configuration files. Default: None.

  • agent_start_port (int, optional) – The starting port of the agents that link to the worker. Default: 7000.

  • agent_ip (str, optional) – The local agent IP. If None, the agent IP is obtained from the rank table file. agent_ip and rank_start must either both be set or both be None. Default: None.

  • rank_start (int, optional) – The starting rank ID of this machine. If None, the rank ID is obtained from the rank table file. agent_ip and rank_start must either both be set or both be None. Default: None.

  • dec_key (bytes, optional) – Byte type key used for decryption. The valid length is 16, 24, or 32. Default: None.

  • dec_mode (str, optional) – The decryption mode, which takes effect only when dec_key is set. Options: 'AES-GCM' or 'AES-CBC'. Default: 'AES-GCM'.

Raises

RuntimeError – Failed to start agents.
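The constraint that agent_ip and rank_start are given together or not at all can be expressed as a short check. This is a sketch of the documented rule with a hypothetical helper name, agent_address_source, and a placeholder IP address. It is not taken from the library.

```python
def agent_address_source(agent_ip=None, rank_start=None):
    """Hypothetical helper: enforce the both-or-neither rule for the two options."""
    # Compare against None explicitly because rank_start=0 is a valid rank id.
    if (agent_ip is None) != (rank_start is None):
        raise RuntimeError("agent_ip and rank_start must both be set or both be None")
    if agent_ip is None:
        return "rank table file"
    return "explicit arguments"


print(agent_address_source())
print(agent_address_source(agent_ip="192.168.1.2", rank_start=0))
```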

Examples

>>> from mindspore_serving.server import distributed
>>> model_files = []
>>> for i in range(8):
...     model_files.append(f"models/device{i}/matmul.mindir")
>>> distributed.startup_agents(distributed_address="127.0.0.1:6200", model_files=model_files)