mindsponge.metrics.local_distance_difference_test

mindsponge.metrics.local_distance_difference_test(predicted_points, true_points, true_points_mask, cutoff=15, per_residue=False)

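Compute the local distance difference test (lDDT) score between predicted and true \(C\alpha\) atom positions. First, calculate the distance matrices of the true and predicted \(C\alpha\) atoms, \(D_{ij} = \lVert x_i - x_j \rVert_2\). Then, for each of four thresholds, compute the rate (fraction of scored atom pairs) at which the distance difference falls below that threshold, and average the four rates: \(lddt = (rate(|D_{true} - D_{pred}| < 0.5) + rate(|D_{true} - D_{pred}| < 1.0) + rate(|D_{true} - D_{pred}| < 2.0) + rate(|D_{true} - D_{pred}| < 4.0))/4\). Only pairs \((i, j)\) with \(i \neq j\), whose true distance is below cutoff and whose residues are both present in true_points_mask, are scored. See Jumper et al. (2021), Suppl. Alg. 29 “predictPerResidueLDDT_Ca”.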
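The per-pair part of this formula can be sketched in plain Python. This is an illustrative sketch, not MindSponge code: `distance_matrix` and `pair_score` are hypothetical helpers that work on lists of (x, y, z) tuples instead of tensors. For example, a distance difference of 0.7 passes the 1.0, 2.0 and 4.0 thresholds but not 0.5, so that pair scores 0.75.

```python
import math

# The four lDDT thresholds from the formula above.
THRESHOLDS = (0.5, 1.0, 2.0, 4.0)


def distance_matrix(points):
    """Hypothetical helper: D[i][j] = ||x_i - x_j||_2 for a list of (x, y, z) tuples."""
    return [[math.dist(p, q) for q in points] for p in points]


def pair_score(d_true, d_pred):
    """Hypothetical helper: fraction of the four thresholds that |d_true - d_pred| is strictly below."""
    diff = abs(d_true - d_pred)
    return sum(diff < t for t in THRESHOLDS) / len(THRESHOLDS)


print(distance_matrix([(0, 0, 0), (3, 4, 0)]))  # [[0.0, 5.0], [5.0, 0.0]]
print(pair_score(3.0, 3.7))  # 0.75
```

Averaging `pair_score` over all scored pairs gives the global score; averaging over the pairs of a single residue gives its per-residue score.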

Parameters
  • predicted_points (Tensor) – Predicted \(C\alpha\) atom positions, a tensor of shape \((1, N_{res}, 3)\), where \(N_{res}\) is the number of residues in the protein.

  • true_points (Tensor) – Ground-truth \(C\alpha\) atom positions, a tensor of shape \((1, N_{res}, 3)\).

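  • true_points_mask (Tensor) – Binary mask of shape \((1, N_{res}, 1)\) marking which residues have a valid ground-truth \(C\alpha\) position (1 for present, 0 for missing); pairs involving a masked residue are not scored.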

  • cutoff (float) – Distance cutoff in the true structure: only atom pairs whose true distance is below this value are scored. Default: 15.

  • per_residue (bool) – If True, return one lDDT score per residue; if False, return a single score averaged over all scored pairs. Default: False.
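To make the roles of cutoff, true_points_mask and per_residue concrete, here is a minimal pure-Python sketch for a single structure (the leading batch dimension of size 1 is dropped). It is not MindSponge's implementation: `lddt_ca` is a hypothetical name, and the pair weighting is an assumption modeled on the AlphaFold reference lddt, where a pair \((i, j)\) with \(i \neq j\) counts only if its true distance is below cutoff and is weighted by the product of the two mask values.

```python
import math

THRESHOLDS = (0.5, 1.0, 2.0, 4.0)


def lddt_ca(predicted_points, true_points, true_points_mask, cutoff=15.0, per_residue=False):
    """Hypothetical lDDT-Calpha sketch for ONE structure (no batch axis).

    predicted_points, true_points: lists of N (x, y, z) tuples.
    true_points_mask: list of N floats, 1.0 if the true residue exists, else 0.0.
    Assumption: weighting follows the AlphaFold reference lddt
    (self pairs excluded, true distance < cutoff, both residues present).
    """
    n = len(true_points)
    eps = 1e-10  # avoids division by zero when no pair is scored
    row_num = [0.0] * n  # weighted sum of pair scores for residue i
    row_den = [0.0] * n  # total weight of scored pairs for residue i
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # exclude self-interaction
            d_true = math.dist(true_points[i], true_points[j])
            d_pred = math.dist(predicted_points[i], predicted_points[j])
            # Weight is 0 unless the pair is local in the true structure and both residues exist.
            weight = float(d_true < cutoff) * true_points_mask[i] * true_points_mask[j]
            diff = abs(d_true - d_pred)
            score = sum(diff < t for t in THRESHOLDS) / len(THRESHOLDS)
            row_num[i] += weight * score
            row_den[i] += weight
    if per_residue:
        # Normalize each residue over its own scored pairs: one value per residue.
        return [(eps + num) / (eps + den) for num, den in zip(row_num, row_den)]
    # Pool all scored pairs: a single value for the whole structure.
    return (eps + sum(row_num)) / (eps + sum(row_den))


true_pts = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 4.0, 0.0)]
pred_pts = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 4.7, 0.0)]
print(lddt_ca(pred_pts, true_pts, [1.0, 1.0, 1.0], per_residue=True))
```

Because the global score pools every scored pair, residues with more neighbors within cutoff carry more weight than they would in a plain mean of the per-residue scores.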

Returns

  • score (Tensor) – The lDDT score, of shape \((1,)\) if per_residue is False, or \((1, N_{res})\) otherwise.

Supported Platforms:

Ascend, GPU

Examples

>>> import numpy as np
>>> np.random.seed(0)
>>> from mindsponge.metrics import local_distance_difference_test
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> predicted_points = Tensor(np.random.rand(1, 256, 3)).astype(mstype.float32)
>>> true_points = Tensor(np.random.rand(1, 256, 3)).astype(mstype.float32)
>>> true_points_mask = Tensor(np.random.rand(1, 256, 1)).astype(mstype.float32)
>>> lddt = local_distance_difference_test(predicted_points, true_points, true_points_mask,
...                                       cutoff=15, per_residue=False)
>>> print(lddt)
[0.9554313]
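The score in this example is not representative of real structures. np.random.rand draws coordinates from [0, 1), so every pairwise distance is below √3 ≈ 1.73 and every \(|D_{true} - D_{pred}|\) is below 2.0. Each scored pair therefore passes at least the 2.0 and 4.0 thresholds and scores at least 0.5; assuming the result is a weighted average over pairs with nonnegative weights, it cannot fall below 0.5. The hypothetical helper `max_distance_difference` below checks the distance bound using Python's `random` module, so its coordinates differ from the NumPy draws above.

```python
import itertools
import math
import random


def max_distance_difference(n_points, seed=0):
    """Largest |D_true - D_pred| over all pairs of two random point clouds in [0, 1)^3."""
    rng = random.Random(seed)  # Python's RNG: coordinates differ from the NumPy example
    true_pts = [tuple(rng.random() for _ in range(3)) for _ in range(n_points)]
    pred_pts = [tuple(rng.random() for _ in range(3)) for _ in range(n_points)]
    return max(
        abs(math.dist(true_pts[i], true_pts[j]) - math.dist(pred_pts[i], pred_pts[j]))
        for i, j in itertools.combinations(range(n_points), 2)
    )


# Any two points in the unit cube are less than sqrt(3) apart, so every
# distance difference is below 2.0: the 2.0 and 4.0 thresholds always pass.
print(max_distance_difference(64) < math.sqrt(3) < 2.0)  # True
```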