Introduction to a New Method for Edge-Cloud Collaborative Training for Privacy Protection
Authors: Wang Sen, Wang Peng, Yao Xin, Cui Jinkai, Hu Qintao, Chen Renhai, Zhang Gong | Organization: Theory Lab, 2012 Laboratories
Paper Title
MistNet: Towards Private Neural Network Training with Local Differential Privacy
Paper URL: https://github.com/TL-System/plato/blob/main/docs/papers/MistNet.pdf
Code URLs
Plato: https://github.com/TL-System/plato
Sedna: https://github.com/kubeedge/sedna
01 Research Background
Since Google first proposed federated learning in the edge AI field, it has developed rapidly in both academia and industry. Two major challenges in edge AI are data heterogeneity and data privacy, both of which can be addressed by applying federated learning to edge computing. FedAvg, the canonical federated learning algorithm, selects a subset of clients to participate in each training round, which reduces communication pressure and avoids unreliable links. Because clients upload only training gradients rather than raw data, user data is not directly exposed (a minimal sketch of the server-side aggregation step follows the list of bottlenecks below). Nevertheless, FedAvg still faces three main bottlenecks:
(1) As the size of the model increases, the volume of data transmitted surges, which can become a bottleneck that hinders system performance.
(2) Research in deep learning has shown that gradients can still leak information about the original training data, allowing attackers to potentially reconstruct users' private data from them.
(3) Edge computing capabilities vary greatly, with some devices being unable to complete the training process or slowing down the synchronization progress of federated learning due to insufficient computing power.
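For background, the server-side aggregation step of FedAvg can be summarized in a few lines. This is a minimal illustrative sketch; the function and data structures are ours, not from the MistNet code.

import numpy as np

def fedavg_aggregate(client_weights, client_sizes):
    """Average client models weighted by local dataset size.

    client_weights: list of {param_name: np.ndarray} dicts, one per client.
    client_sizes: number of local samples per client.
    """
    total = float(sum(client_sizes))
    return {name: sum(w[name] * (n / total) for w, n in zip(client_weights, client_sizes))
            for name in client_weights[0]}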
02 Paper Abstract
To address these three issues with the FedAvg algorithm, we propose the MistNet algorithm. It divides a pre-trained DNN model into two parts: a feature extractor on the edge side and a classifier on the cloud. Empirical observations of deep network training show that new data rarely updates the parameters of the early feature-extraction layers, but does update the parameters of the classifier. We therefore keep the edge-side parameters fixed, use the feature extractor to turn input data into representation data, send the representation data from the client to the server, and train the classifier on the cloud. MistNet is optimized for edge scenarios in three ways (an end-to-end sketch follows the list):
(1) Reduces the volume of network transmission between the edge and the cloud. Instead of the multiple rounds of gradient exchange performed in traditional federated learning, the extracted representation data is transmitted to the cloud once for aggregated training, cutting both the frequency and the total volume of edge-cloud communication.
(2) Enhances privacy protection by quantizing, adding noise to, compressing, and perturbing the representation data. This makes it harder to infer the original data from the representation data held on the cloud, strengthening privacy protection.
(3) Reduces computing resource requirements on the edge side by segmenting the pre-trained model and using its first several layers as a feature extractor, thereby reducing the computing workload on the client. Feature extraction on the edge can be treated as pure inference, which allows federated learning to be completed on edge hardware that supports only inference.
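The end-to-end flow can be sketched as follows. This is a hedged, NumPy-only illustration of the steps described above, not the paper's actual implementation; the extractor output is faked with random values.

import numpy as np

rng = np.random.default_rng(0)

def quantize_and_perturb(features, epsilon):
    """1-bit quantization followed by symmetric randomized response (LDP)."""
    bits = (features > 0).astype(np.uint8)
    keep_p = np.exp(epsilon / 2) / (np.exp(epsilon / 2) + 1)  # probability a bit is kept
    flip = rng.random(bits.shape) > keep_p
    return np.where(flip, 1 - bits, bits)

# Edge side: one inference-only pass plus perturbation, then a single upload.
features = rng.standard_normal((16, 128))  # stand-in for frozen-extractor outputs
upload = quantize_and_perturb(features, epsilon=1.0)
# Cloud side: the classifier is then trained on `upload`; no further
# edge-cloud gradient exchange is needed.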
Experiments have shown that, compared with FedAvg, the MistNet algorithm significantly reduces communication overheads and edge computing workloads, by up to five times and ten times, respectively. In addition, MistNet's training accuracy exceeds that of FedAvg, and convergence efficiency for training on object detection tasks improves by up to 30%.
03 Algorithm Framework and Technical Key Points
Technical Key Point 1: Model Segmentation and Representation Migration
Because the first several layers of a deep neural network transfer well across related tasks, the server can train a model on existing data from a related or similar field and extract those layers as a feature extractor. The client then obtains the feature extractor from a secure third party or the server and can select a feature extractor and local data for fine-tuning (a minimal split-and-freeze sketch follows Figure 1).

Figure 1: Schematic diagram of feature extraction
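As a minimal sketch of this idea in MindSpore (the layer list and cut point are illustrative, not the paper's exact API), the split-and-freeze step might look like this:

import mindspore.nn as nn

def split_pretrained(layers, cut_index):
    """Split an ordered list of nn.Cell blocks into an edge extractor and a cloud classifier."""
    extractor = nn.SequentialCell(layers[:cut_index])   # shipped to the edge
    classifier = nn.SequentialCell(layers[cut_index:])  # kept on the cloud
    for param in extractor.trainable_params():
        param.requires_grad = False  # the edge side only runs inference
    return extractor, classifier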
Technical Key Point 2: Quantization Solution for Representation Data
The communication volume can be effectively reduced by quantizing and compressing the representation data at the middle layer. An extreme solution is 1-bit quantization of the activation outputs: although most of the representation content is discarded, this effectively prevents information leakage, as the back-of-the-envelope example below shows.
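The communication saving is easy to see in a small example (the feature-map shape is illustrative): 1-bit quantization plus bit packing shrinks a float32 representation by a factor of 32.

import numpy as np

feat = np.random.randn(128, 80, 80).astype(np.float32)  # illustrative feature map
bits = (feat > 0).astype(np.uint8)                       # 1-bit quantization of activations
packed = np.packbits(bits)                               # pack 8 bits per byte for transmission
print(feat.nbytes, packed.nbytes)                        # 3276800 -> 102400 bytes (32x smaller)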

Figure 4: One-click deployment of the edge-cloud collaborative training framework for privacy protection on the Sedna platform
Software and Hardware
Hardware: Atlas 800 (9000) + Atlas 500 (3000)
Software: Ubuntu 18.04.5 LTS x86_64 + EulerOS V2R8 + CANN 5.0.2 + KubeEdge 1.8.2 + Sedna 0.4.0
Test Results

(3) More complex models are more resistant to noise. For the 1.3% and 5.8% feature extractors, a good balance between privacy protection and precision is achieved when ε is 1.

Figure 7: Defense effect against model inversion attacks
We perform white-box tests to simulate model inversion attacks and use SSIM to measure the effect: if the SSIM is below 0.3, the original image cannot be identified. As the figure shows, most feature extractors defend effectively against model inversion attacks once 1-bit quantization and LDP are applied (a hedged sketch of the SSIM check follows).
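A hedged sketch of this SSIM check using scikit-image; the 0.3 threshold follows the text, while the function and variable names are ours.

import numpy as np
from skimage.metrics import structural_similarity

def inversion_defended(original, reconstructed, threshold=0.3):
    """True if the reconstruction is too dissimilar for the original to be identified."""
    score = structural_similarity(original, reconstructed, channel_axis=-1, data_range=255)
    return score < threshold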
05 Code Implementation of NPU + MindSpore + YOLOv5
The code mainly includes the modules for data loading, network design, data privacy protection design, loss function design, and trainer design.
Data loading module:
import multiprocessing
import os
import random

import cv2
import numpy as np
from PIL import Image
from pycocotools.coco import COCO

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV

# NOTE: DistributedSampler, MultiScaleTrans, PreprocessTrueBox, and reshape_fn
# are project-local helpers imported elsewhere in the original repository.

# Minimum number of visible keypoints for an image to count as annotated
# (value assumed from the reference implementation).
min_keypoints_per_image = 10

def _has_only_empty_bbox(anno):
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

def has_valid_annotation(anno):
    """Check annotation file."""
    # if it's empty, there is no annotation
    if not anno:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # keypoint tasks have slightly different criteria for considering
    # an annotation valid
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only consider images valid if they
    # contain at least min_keypoints_per_image
    if _count_visible_keypoints(anno) >= min_keypoints_per_image:
        return True
    return False
class COCOYoloDataset:
    """YOLOV5 Dataset for COCO."""
    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = list(sorted(self.coco.imgs.keys()))
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training
        self.mosaic = True
        # filter images without any annotations
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids
        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }
        self.count = 0
    def _mosaic_preprocess(self, index, input_size):
        """Compose a 2x2 mosaic image from four training images."""
        labels4 = []
        s = 384
        self.mosaic_border = [-s // 2, -s // 2]
        yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]
        indices = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)]
        for i, img_ids_index in enumerate(indices):
            coco = self.coco
            img_id = self.img_ids[img_ids_index]
            img_path = coco.loadImgs(img_id)[0]["file_name"]
            img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
            img = np.array(img)
            h, w = img.shape[:2]
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 128, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b
            ann_ids = coco.getAnnIds(imgIds=img_id)
            target = coco.loadAnns(ann_ids)
            # filter crowd annotations
            if self.filter_crowd_anno:
                annos = [anno for anno in target if anno["iscrowd"] == 0]
            else:
                annos = [anno for anno in target]
            target = {}
            boxes = [anno["bbox"] for anno in annos]
            target["bboxes"] = boxes
            classes = [anno["category_id"] for anno in annos]
            classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
            target["labels"] = classes
            bboxes = target['bboxes']
            labels = target['labels']
            out_target = []
            for bbox, label in zip(bboxes, labels):
                tmp = []
                # convert to [x_min, y_min, x_max, y_max]
                bbox = self._convetTopDown(bbox)
                tmp.extend(bbox)
                tmp.append(int(label))
                # each row of out_target is [x_min, y_min, x_max, y_max, label]
                # in absolute pixel coordinates of the source image
                out_target.append(tmp)
            labels = out_target.copy()
            labels = np.array(labels)
            out_target = np.array(out_target)
            # shift boxes into the mosaic canvas coordinates
            labels[:, 0] = out_target[:, 0] + padw
            labels[:, 1] = out_target[:, 1] + padh
            labels[:, 2] = out_target[:, 2] + padw
            labels[:, 3] = out_target[:, 3] + padh
            labels4.append(labels)
        if labels4:
            labels4 = np.concatenate(labels4, 0)
            np.clip(labels4[:, :4], 0, 2 * s, out=labels4[:, :4])  # use with random_perspective
        flag = np.array([1])
        return img4, labels4, input_size, flag
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            (img, target) (tuple): target is a dictionary containing "bbox", "segmentation" or
                "keypoints", generated by the image's annotation. img is a PIL image.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        if not self.is_training:
            img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
            return img, img_id
        input_size = [640, 640]
        if self.mosaic and random.random() < 0.5:
            return self._mosaic_preprocess(index, input_size)
        img = np.fromfile(os.path.join(self.root, img_path), dtype='int8')
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        # filter crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in target if anno["iscrowd"] == 0]
        else:
            annos = [anno for anno in target]
        target = {}
        boxes = [anno["bbox"] for anno in annos]
        target["bboxes"] = boxes
        classes = [anno["category_id"] for anno in annos]
        classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
        target["labels"] = classes
        bboxes = target['bboxes']
        labels = target['labels']
        out_target = []
        for bbox, label in zip(bboxes, labels):
            tmp = []
            # convert to [x_min, y_min, x_max, y_max]
            bbox = self._convetTopDown(bbox)
            tmp.extend(bbox)
            tmp.append(int(label))
            # tmp is [x_min, y_min, x_max, y_max, label]
            out_target.append(tmp)
        flag = np.array([0])
        return img, out_target, input_size, flag

    def __len__(self):
        return len(self.img_ids)

    def _convetTopDown(self, bbox):
        x_min = bbox[0]
        y_min = bbox[1]
        w = bbox[2]
        h = bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create dataset for YOLOV5."""
    cv2.setNumThreads(0)
    de.config.set_enable_shared_mem(True)
    if is_training:
        filter_crowd = True
        remove_empty_anno = True
    else:
        filter_crowd = False
        remove_empty_anno = False
    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno, is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    yolo_dataset.size = len(distributed_sampler)
    hwc_to_chw = CV.HWC2CHW()
    config.dataset_size = len(yolo_dataset)
    cores = multiprocessing.cpu_count()
    num_parallel_workers = int(cores / device_num)
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        yolo_dataset.transforms = multi_scale_trans
        dataset_column_names = ["image", "annotation", "input_size", "mosaic_flag"]
        output_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                               "gt_box1", "gt_box2", "gt_box3"]
        map1_out_column_names = ["image", "annotation", "size"]
        map2_in_column_names = ["annotation", "size"]
        map2_out_column_names = ["annotation", "bbox1", "bbox2", "bbox3",
                                 "gt_box1", "gt_box2", "gt_box3"]
        ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler,
                                 python_multiprocessing=True, num_parallel_workers=min(4, num_parallel_workers))
        ds = ds.map(operations=multi_scale_trans, input_columns=dataset_column_names,
                    output_columns=map1_out_column_names, column_order=map1_out_column_names,
                    num_parallel_workers=min(12, num_parallel_workers), python_multiprocessing=True)
        ds = ds.map(operations=PreprocessTrueBox(config), input_columns=map2_in_column_names,
                    output_columns=map2_out_column_names, column_order=output_column_names,
                    num_parallel_workers=min(4, num_parallel_workers), python_multiprocessing=False)
        mean = [m * 255 for m in [0.485, 0.456, 0.406]]
        std = [s * 255 for s in [0.229, 0.224, 0.225]]
        ds = ds.map([CV.Normalize(mean, std),
                     hwc_to_chw], num_parallel_workers=min(4, num_parallel_workers))

        def concatenate(images):
            # space-to-depth: split each 2x2 pixel block across the channel axis
            images = np.concatenate((images[..., ::2, ::2], images[..., 1::2, ::2],
                                     images[..., ::2, 1::2], images[..., 1::2, 1::2]), axis=0)
            return images

        ds = ds.map(operations=concatenate, input_columns="image",
                    num_parallel_workers=min(4, num_parallel_workers))
        ds = ds.batch(batch_size, num_parallel_workers=min(4, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler)
        compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
        ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    num_parallel_workers=8)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)
    return ds, len(yolo_dataset)
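A hedged usage sketch for the function above; the paths, batch size, and the config object (a ConfigYOLOV5-style namespace with the attributes the transforms expect) are illustrative.

config = ConfigYOLOV5()
ds, data_size = create_yolo_dataset(
    image_dir="./coco/train2017",
    anno_path="./coco/annotations/instances_train2017.json",
    batch_size=8, max_epoch=300, device_num=1, rank=0,
    config=config, is_training=True)
print("dataset size:", data_size)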
Network design module:
The backbone is divided into two parts: one on the client and the other on the server.
# Assumed imports for the network modules (per the reference implementation):
# import mindspore as ms
# import mindspore.nn as nn
# from mindspore.ops import operations as P, functional as F
# Conv, C3, Focusv2, SPP, YoloBlock, DetectionBlock, and ConfigYOLOV5 are
# project-local building blocks.

class YOLOv5Backbone_from(nn.Cell):
    """First part of the backbone; runs on the edge as the feature extractor."""
    def __init__(self):
        super(YOLOv5Backbone_from, self).__init__()
        self.tenser_to_array = P.TupleToArray()
        self.focusv2 = Focusv2(3, 32, k=3, s=1)
        self.conv1 = Conv(32, 64, k=3, s=2)
        self.C31 = C3(64, 64, n=1)
        self.conv2 = Conv(64, 128, k=3, s=2)

    def construct(self, x, input_shape):
        """construct method"""
        #img_hight = P.Shape()(x)[2] * 2
        #img_width = P.Shape()(x)[3] * 2
        input_shape = F.shape(x)[2:4]
        input_shape = F.cast(self.tenser_to_array(input_shape) * 2, ms.float32)
        fcs = self.focusv2(x)
        cv1 = self.conv1(fcs)
        bcsp1 = self.C31(cv1)
        cv2 = self.conv2(bcsp1)
        return cv2, input_shape
class YOLOv5Backbone_to(nn.Cell):
    """Second part of the backbone; runs on the cloud/server side."""
    def __init__(self):
        super(YOLOv5Backbone_to, self).__init__()
        self.C32 = C3(128, 128, n=3)
        self.conv3 = Conv(128, 256, k=3, s=2)
        self.C33 = C3(256, 256, n=3)
        self.conv4 = Conv(256, 512, k=3, s=2)
        self.spp = SPP(512, 512, k=[5, 9, 13])
        self.C34 = C3(512, 512, n=1, shortcut=False)

    def construct(self, cv2):
        """construct method"""
        bcsp2 = self.C32(cv2)
        cv3 = self.conv3(bcsp2)
        bcsp3 = self.C33(cv3)
        cv4 = self.conv4(bcsp3)
        spp1 = self.spp(cv4)
        bcsp4 = self.C34(spp1)
        return bcsp2, bcsp3, bcsp4
Overall network architecture of the server:
class YOLOV5s(nn.Cell):
    """
    YOLOV5 network.

    Args:
        is_training: Bool. Whether train or not.

    Returns:
        Cell, cell instance of YOLOV5 neural network.

    Examples:
        YOLOV5s(True)
    """
    def __init__(self, is_training):
        super(YOLOV5s, self).__init__()
        self.config = ConfigYOLOV5()
        # YOLOv5 network
        self.feature_map = YOLOv5(backbone=YOLOv5Backbone_to(),
                                  out_channel=self.config.out_channel)
        # prediction on the default anchor boxes
        self.detect_1 = DetectionBlock('l', is_training=is_training)
        self.detect_2 = DetectionBlock('m', is_training=is_training)
        self.detect_3 = DetectionBlock('s', is_training=is_training)

    def construct(self, x, img_hight, img_width, input_shape):
        small_object_output, medium_object_output, big_object_output = self.feature_map(x, img_hight, img_width)
        output_big = self.detect_1(big_object_output, input_shape)
        output_me = self.detect_2(medium_object_output, input_shape)
        output_small = self.detect_3(small_object_output, input_shape)
        # big is the final output which has the smallest feature map
        return output_big, output_me, output_small
class YOLOv5(nn.Cell):
    def __init__(self, backbone, out_channel):
        super(YOLOv5, self).__init__()
        self.out_channel = out_channel
        self.backbone = backbone
        self.conv1 = Conv(512, 256, k=1, s=1)         # 10
        self.C31 = C3(512, 256, n=1, shortcut=False)  # 11
        self.conv2 = Conv(256, 128, k=1, s=1)
        self.C32 = C3(256, 128, n=1, shortcut=False)  # 13
        self.conv3 = Conv(128, 128, k=3, s=2)
        self.C33 = C3(256, 256, n=1, shortcut=False)  # 15
        self.conv4 = Conv(256, 256, k=3, s=2)
        self.C34 = C3(512, 512, n=1, shortcut=False)  # 17
        self.backblock1 = YoloBlock(128, 255)
        self.backblock2 = YoloBlock(256, 255)
        self.backblock3 = YoloBlock(512, 255)
        self.concat = P.Concat(axis=1)

    def construct(self, x, img_hight, img_width):
        """
        input_shape of x is (batch_size, 3, h, w)
        feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
        feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
        feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
        """
        backbone4, backbone6, backbone9 = self.backbone(x)
        cv1 = self.conv1(backbone9)  # 10
        ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(cv1)
        concat1 = self.concat((ups1, backbone6))
        bcsp1 = self.C31(concat1)  # 13
        cv2 = self.conv2(bcsp1)
        ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(cv2)  # 15
        concat2 = self.concat((ups2, backbone4))
        bcsp2 = self.C32(concat2)  # 17
        cv3 = self.conv3(bcsp2)
        concat3 = self.concat((cv3, cv2))
        bcsp3 = self.C33(concat3)  # 20
        cv4 = self.conv4(bcsp3)
        concat4 = self.concat((cv4, cv1))
        bcsp4 = self.C34(concat4)  # 23
        small_object_output = self.backblock1(bcsp2)   # h/8, w/8
        medium_object_output = self.backblock2(bcsp3)  # h/16, w/16
        big_object_output = self.backblock3(bcsp4)     # h/32, w/32
        return small_object_output, medium_object_output, big_object_output
Data privacy protection design module:
import mindspore
import mindspore.numpy as mnp

def encode_1b(x):
    """1-bit quantization: map non-positive activations to 0 and positive ones to 1."""
    x[(x <= 0)] = 0
    x[(x > 0)] = 1
    return x

def randomize_1b(bit_tensor, epsilon):
    """
    The default unary encoding method is symmetric.
    """
    return symmetric_tensor_encoding_1b(bit_tensor, epsilon)

def symmetric_tensor_encoding_1b(bit_tensor, epsilon):
    # keep-probability p and flip-probability q of symmetric unary encoding
    p = mnp.exp(epsilon / 2) / (mnp.exp(epsilon / 2) + 1)
    q = 1 / (mnp.exp(epsilon / 2) + 1)
    return produce_random_response_1b(bit_tensor, p, q)

def produce_random_response_1b(bit_tensor, p, q=None):
    """
    Implements randomized response as the perturbation method.
    A Bernoulli draw is built from a uniform distribution because the tensor
    backend does not provide a binomial sampler.
    """
    q = 1 - p if q is None else q
    uniformreal = mindspore.ops.UniformReal(seed=2)
    binomial = uniformreal(bit_tensor.shape)
    zeroslike = mindspore.ops.ZerosLike()
    oneslike = mindspore.ops.OnesLike()
    # a 1 bit is reported truthfully with probability p; a 0 bit is reported
    # as 1 with probability q
    p_binomial = mnp.where(binomial > q, oneslike(bit_tensor), zeroslike(bit_tensor))
    q_binomial = mnp.where(binomial <= q, oneslike(bit_tensor), zeroslike(bit_tensor))
    return mnp.where(bit_tensor == 1, p_binomial, q_binomial)
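A hedged usage sketch tying the helpers above together on an edge-side feature map; the shape and epsilon value are illustrative.

import mindspore.numpy as mnp

features = mnp.randn(1, 128, 80, 80)         # stand-in for the frozen extractor's output
bits = encode_1b(features)                   # 1-bit quantization
protected = randomize_1b(bits, epsilon=1.0)  # randomized response with epsilon = 1
# `protected` is what the edge uploads to the cloud.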
Loss function module:
class YoloWithLossCell(nn.Cell):
    """YOLOV5 loss."""
    def __init__(self, network):
        super(YoloWithLossCell, self).__init__()
        self.yolo_network = network
        self.config = ConfigYOLOV5()
        self.loss_big = YoloLossBlock('l', self.config)
        self.loss_me = YoloLossBlock('m', self.config)
        self.loss_small = YoloLossBlock('s', self.config)

    def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, img_hight, img_width, input_shape):
        yolo_out = self.yolo_network(x, img_hight, img_width, input_shape)
        loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
        loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
        loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
        return loss_l + loss_m + loss_s * 0.2
class TrainingWrapper(nn.Cell):
    """Training wrapper."""
    # Assumes the standard MindSpore imports: context, ParallelMode,
    # composite ops (C), auto_parallel_context, and get_group_size.
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
class Giou(nn.Cell):
    """Calculating giou"""
    def __init__(self):
        super(Giou, self).__init__()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.min = P.Minimum()
        self.max = P.Maximum()
        self.concat = P.Concat(axis=1)
        self.mean = P.ReduceMean()
        self.div = P.RealDiv()
        self.eps = 0.000001

    def construct(self, box_p, box_gt):
        """construct method"""
        box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * (box_p[..., 3:4] - box_p[..., 1:2])
        box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * (box_gt[..., 3:4] - box_gt[..., 1:2])
        x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1])
        x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3])
        y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2])
        y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4])
        intersection = (y_2 - y_1) * (x_2 - x_1)
        xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1])
        xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3])
        yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2])
        yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4])
        c_area = (xc_2 - xc_1) * (yc_2 - yc_1)
        union = box_p_area + box_gt_area - intersection
        union = union + self.eps
        c_area = c_area + self.eps
        iou = self.div(self.cast(intersection, ms.float32), self.cast(union, ms.float32))
        res_mid0 = c_area - union
        res_mid1 = self.div(self.cast(res_mid0, ms.float32), self.cast(c_area, ms.float32))
        giou = iou - res_mid1
        giou = C.clip_by_value(giou, -1.0, 1.0)
        return giou
class Iou(nn.Cell):
    """Calculate the iou of boxes"""
    def __init__(self):
        super(Iou, self).__init__()
        self.min = P.Minimum()
        self.max = P.Maximum()

    def construct(self, box1, box2):
        """
        box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h]
        box2: gt_box [batch, 1, 1, 1, maxbox, 4]
        convert to topLeft and rightDown
        """
        box1_xy = box1[:, :, :, :, :, :2]
        box1_wh = box1[:, :, :, :, :, 2:4]
        box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0)  # topLeft
        box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0)  # rightDown
        box2_xy = box2[:, :, :, :, :, :2]
        box2_wh = box2[:, :, :, :, :, 2:4]
        box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
        box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
        intersect_mins = self.max(box1_mins, box2_mins)
        intersect_maxs = self.min(box1_maxs, box2_maxs)
        intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
        # P.Squeeze: for efficient slicing
        intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
                         P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
        box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
        box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
        iou = intersect_area / (box1_area + box2_area - intersect_area)
        # iou : [batch, gx, gy, anchors, maxboxes]
        return iou
class YoloLossBlock(nn.Cell):
    """
    Loss block cell of YOLOV5 network.
    """
    def __init__(self, scale, config=ConfigYOLOV5()):
        super(YoloLossBlock, self).__init__()
        self.config = config
        if scale == 's':
            # anchor mask
            idx = (0, 1, 2)
        elif scale == 'm':
            idx = (3, 4, 5)
        elif scale == 'l':
            idx = (6, 7, 8)
        else:
            raise KeyError("Invalid scale value for DetectionBlock")
        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
        self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
        self.concat = P.Concat(axis=-1)
        self.iou = Iou()
        self.reduce_max = P.ReduceMax(keep_dims=False)
        self.confidence_loss = ConfidenceLoss()
        self.class_loss = ClassLoss()
        self.reduce_sum = P.ReduceSum()
        self.giou = Giou()

    def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
        """
        prediction : origin output from yolo
        pred_xy: (sigmoid(xy)+grid)/grid_size
        pred_wh: (exp(wh)*anchors)/input_shape
        y_true : after normalize
        gt_box: [batch, maxboxes, xyhw] after normalize
        """
        object_mask = y_true[:, :, :, :, 4:5]
        class_probs = y_true[:, :, :, :, 5:]
        true_boxes = y_true[:, :, :, :, :4]
        grid_shape = P.Shape()(prediction)[1:3]
        grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
        pred_boxes = self.concat((pred_xy, pred_wh))
        true_wh = y_true[:, :, :, :, 2:4]
        true_wh = P.Select()(P.Equal()(true_wh, 0.0),
                             P.Fill()(P.DType()(true_wh),
                                      P.Shape()(true_wh), 1.0),
                             true_wh)
        true_wh = P.Log()(true_wh / self.anchors * input_shape)
        # 2-w*h: down-weight large boxes, since small objects need more precision
        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
        gt_shape = P.Shape()(gt_box)
        gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
        # add one more dimension for broadcast
        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
        # gt_box is x,y,h,w after normalize
        # [batch, grid[0], grid[1], num_anchor, num_gt]
        best_iou = self.reduce_max(iou, -1)
        # [batch, grid[0], grid[1], num_anchor]
        # ignore_mask: IOU too small
        ignore_mask = best_iou < self.ignore_threshold
        ignore_mask = P.Cast()(ignore_mask, ms.float32)
        ignore_mask = P.ExpandDims()(ignore_mask, -1)
        # backpropagating through ignore_mask would add costly MaximumGrad and
        # MinimumGrad ops, so its gradient is turned off
        ignore_mask = F.stop_gradient(ignore_mask)
        confidence_loss = self.confidence_loss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
        class_loss = self.class_loss(object_mask, prediction[:, :, :, :, 5:], class_probs)
        object_mask_me = P.Reshape()(object_mask, (-1, 1))  # [8, 72, 72, 3, 1]
        box_loss_scale_me = P.Reshape()(box_loss_scale, (-1, 1))
        pred_boxes_me = xywh2x1y1x2y2(pred_boxes)
        pred_boxes_me = P.Reshape()(pred_boxes_me, (-1, 4))
        true_boxes_me = xywh2x1y1x2y2(true_boxes)
        true_boxes_me = P.Reshape()(true_boxes_me, (-1, 4))
        # GIoU between predicted and ground-truth boxes (variable named ciou in the original code)
        ciou = self.giou(pred_boxes_me, true_boxes_me)
        ciou_loss = object_mask_me * box_loss_scale_me * (1 - ciou)
        ciou_loss_me = self.reduce_sum(ciou_loss, ())
        loss = ciou_loss_me * 4 + confidence_loss + class_loss
        batch_size = P.Shape()(prediction)[0]
        return loss / batch_size
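As a quick numeric sanity check of the GIoU term used above, here is a plain-Python version on two illustrative boxes in [x1, y1, x2, y2] form; the helper is ours, not part of the module.

def giou(b1, b2, eps=1e-6):
    inter_w = max(0.0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    inter_h = max(0.0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    inter = inter_w * inter_h
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = area1 + area2 - inter + eps
    c_area = ((max(b1[2], b2[2]) - min(b1[0], b2[0])) *
              (max(b1[3], b2[3]) - min(b1[1], b2[1]))) + eps  # smallest enclosing box
    return inter / union - (c_area - union) / c_area

print(giou([0, 0, 2, 2], [1, 1, 3, 3]))  # IoU 1/7 minus enclosure penalty 2/9 ~= -0.079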
Trainer design module:
import math
from collections import Counter

import numpy as np

def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Linear learning rate."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr

def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Warmup step learning rate."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)
    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma**milestones_steps_counter[i]
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)

def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Cosine annealing learning rate."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Cosine annealing learning rate V2."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    last_lr = 0
    last_epoch_V1 = 0
    T_max_V2 = int(max_epoch * 1 / 3)
    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            if i < total_steps * 2 / 3:
                lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
                last_lr = lr
                last_epoch_V1 = last_epoch
            else:
                base_lr = last_lr
                last_epoch = last_epoch - last_epoch_V1
                lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Warmup cosine annealing learning rate."""
    start_sample_epoch = 60
    step_sample = 2
    tobe_sampled_epoch = 60
    end_sampled_epoch = start_sample_epoch + step_sample * tobe_sampled_epoch
    max_sampled_epoch = max_epoch + tobe_sampled_epoch
    T_max = max_sampled_epoch
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    lr_each_step = []
    for i in range(total_sampled_steps):
        last_epoch = i // steps_per_epoch
        if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
            continue
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)
    assert total_steps == len(lr_each_step)
    return np.array(lr_each_step).astype(np.float32)

def get_lr(args):
    """Generate learning rate."""
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch, args.warmup_epochs,
                                        args.max_epoch, args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr, args.steps_per_epoch, args.warmup_epochs,
                                           args.max_epoch, args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch, args.warmup_epochs,
                                               args.max_epoch, args.T_max, args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    return lr
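A hedged usage sketch: the per-step array returned by these schedulers can be passed to a MindSpore optimizer as a dynamic learning rate. All hyperparameter values are illustrative, and `network` stands for the loss-wrapped model built earlier.

import mindspore.nn as nn
from mindspore import Tensor

lr = warmup_cosine_annealing_lr(lr=0.01, steps_per_epoch=100, warmup_epochs=4,
                                max_epoch=300, T_max=300)
opt = nn.Momentum(params=network.trainable_params(), learning_rate=Tensor(lr),
                  momentum=0.9, weight_decay=0.0005)
train_net = TrainingWrapper(network, opt)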
06 Conclusion and Outlook
In this paper, a novel edge-cloud collaborative training method for privacy protection is proposed. Unlike previous methods that require frequent communication between edge devices and the cloud, MistNet uploads intermediate features from the edge to the cloud only once during training, significantly reducing the communication volume between the edge and the cloud. By quantizing, adding noise to, compressing, and perturbing the representation data, the method makes it harder to infer the original data from the representations held on the cloud, strengthening data privacy.
In addition, the pre-trained model is split so that its first several layers serve as a feature extractor, reducing the computing workload on the client. MistNet thus alleviates the shortcomings of federated learning algorithms such as FedAvg. Nevertheless, new federated-learning-based algorithms that combine low communication volume, strong privacy protection, and minimal edge computing workloads are certainly worth further exploration and research.