# Missing API Processing Policy

[View Source on Gitee](https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/source_zh_cn/migration_guide/missing_api_processing_policy.md)

The following methods are available for handling a missing API.

## 1. Equivalent Replacement

In some scenarios, the function of an API can be replaced equivalently. For example:

- Squeeze, Flatten, ExpandDims, and other APIs that perform no actual computation and only change the Tensor shape can all be replaced with Reshape (see the sketch after this list).
- AdaptiveAvgPool and AdaptiveMaxPool with an output shape of 1 are equivalent to ReduceMean and ReduceMax with keep_dims=True.
- MaxPool and MaxPoolWithArgmax are equivalent when the indices output is not used.
- Sort is equivalent to TopK in the full-sort scenario.
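As an illustration, here is a minimal sketch of the first two replacements. It assumes a MindSpore version where the functional `ops.squeeze`, `ops.reshape`, and `ops.adaptive_avg_pool2d` interfaces are available; the shapes are illustrative only:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

x = ms.Tensor(np.random.randn(2, 1, 3, 4).astype(np.float32))

# Squeeze only changes the shape, so Reshape to the target shape is equivalent.
y1 = ops.squeeze(x, axis=1)                      # shape (2, 3, 4)
y2 = ops.reshape(x, (2, 3, 4))
print(np.allclose(y1.asnumpy(), y2.asnumpy()))   # True

# AdaptiveAvgPool2D with output size 1 averages over the whole spatial plane,
# which matches ReduceMean over the H/W axes with keep_dims=True.
z1 = ops.adaptive_avg_pool2d(x, 1)               # shape (2, 1, 1, 1)
z2 = ops.ReduceMean(keep_dims=True)(x, (2, 3))
print(np.allclose(z1.asnumpy(), z2.asnumpy()))   # True
```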
## 2. Using Existing APIs to Wrap Equivalent Function Logic

For some missing APIs, equivalent functionality can be implemented based on existing MindSpore APIs. The following takes `sigmoid focal loss` as an example.

First, analyze the algorithmic basis of this API. Focal Loss [1] is a method for handling the imbalance between positive and negative samples, and between hard and easy samples, that arises when training a single-stage object detector: it down-weights the loss of well-classified examples with the modulating factor $(1 - p_t)^{\gamma}$.

The commonly used `sigmoid focal loss` API is the MMDetection implementation; its PyTorch code is shown in the first block below. According to the API mapping table, every API used in the PyTorch code has a corresponding MindSpore implementation, so nothing is missing. Based on the PyTorch implementation, the MindSpore version is given in the second block.

The PyTorch (MMDetection) implementation is as follows:
```python
import torch.nn.functional as F

def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    if weight is not None:
        loss = loss * weight
    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss

def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25,
                          reduction='mean', avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_."""
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
```
The corresponding MindSpore implementation:
```python
import mindspore as ms
from mindspore import nn, ops

class SigmoidFocalLoss(nn.Cell):
    def __init__(self, weight=None, gamma=2.0, alpha=0.25, reduction='mean', avg_factor=None):
        super(SigmoidFocalLoss, self).__init__()
        self.sigmoid = ops.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = ms.Tensor(weight) if weight is not None else weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.binary_cross_entropy_with_logits = nn.BCEWithLogitsLoss(reduction="none")
        self.is_weight = (weight is not None)

    def reduce_loss(self, loss):
        """Reduce loss as specified.

        Args:
            loss (Tensor): Elementwise loss tensor.

        Return:
            Tensor: Reduced loss tensor.
        """
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def weight_reduce_loss(self, loss):
        # if avg_factor is not specified, just reduce the loss
        if self.avg_factor is None:
            loss = self.reduce_loss(loss)
        else:
            # if reduction is mean, then average the loss by avg_factor
            if self.reduction == 'mean':
                loss = loss.sum() / self.avg_factor
            # if reduction is 'none', then do nothing, otherwise raise an error
            elif self.reduction != 'none':
                raise ValueError('avg_factor can not be used with reduction="sum"')
        return loss

    def construct(self, pred, target):
        pred_sigmoid = self.sigmoid(pred)
        target = ops.cast(target, pred.dtype)
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (self.alpha * target + (1 - self.alpha) *
                        (1 - target)) * ops.pow(pt, self.gamma)
        loss = self.binary_cross_entropy_with_logits(pred, target) * focal_weight
        if self.is_weight:
            weight = self.weight
            if self.weight.shape != loss.shape:
                if self.weight.shape[0] == loss.shape[0]:
                    # For most cases, weight is of shape (num_priors, ),
                    # which means it does not have the second axis num_class
                    weight = self.weight.view(-1, 1)
                elif self.weight.size == loss.size:
                    # Sometimes, weight per anchor per class is also needed,
                    # e.g. in FSAF. But it may be flattened of shape
                    # (num_priors x num_class, ), while loss is still of shape
                    # (num_priors, num_class).
                    weight = self.weight.view(loss.shape[0], -1)
                elif self.weight.ndim != loss.ndim:
                    raise ValueError(f"weight shape {self.weight.shape} does not match loss shape {loss.shape}")
            loss = loss * weight
        loss = self.weight_reduce_loss(loss)
        return loss
```
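To check that the wrapper is numerically equivalent to the reference, both versions can be run on the same random inputs and their outputs compared. The following is a minimal sketch, assuming the two code blocks above have been placed in one script; the `avg_factor` value and tolerance are illustrative:

```python
import numpy as np
import torch
import mindspore as ms

np.random.seed(1)
pred = np.random.uniform(-1, 1, (3, 4)).astype(np.float32)
target = np.random.uniform(0, 1, (3, 4)).astype(np.float32)
weight = np.random.uniform(0, 1, (3,)).astype(np.float32)

# MindSpore wrapper defined above.
loss_ms = SigmoidFocalLoss(weight=weight, avg_factor=0.1)(
    ms.Tensor(pred), ms.Tensor(target))

# PyTorch/MMDetection reference defined above.
loss_pt = py_sigmoid_focal_loss(
    torch.from_numpy(pred), torch.from_numpy(target),
    torch.from_numpy(weight), avg_factor=0.1)

# Both reductions produce a scalar; they should agree within float32 tolerance.
print(np.allclose(loss_ms.asnumpy(), loss_pt.numpy(), rtol=1e-5))
```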