# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data transform MSA TEMPLATE"""
import numpy as np
import mindsponge.common.geometry as geometry
from mindsponge.common.residue_constants import chi_angles_mask, chi_pi_periodic, restype_1to3, chi_angles_atoms, \
    atom_order, residue_atom_renaming_swaps, restype_3to1, MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, restype_order, \
    restypes, restype_name_to_atom14_names, atom_types, residue_atoms, STANDARD_ATOM_MASK, restypes_with_x_and_gap, \
    MSA_PAD_VALUES

# Inclusive bounds of a signed 32-bit integer.
MS_MIN32 = -(2 ** 31)
MS_MAX32 = 2 ** 31 - 1

def one_hot(depth, indices):
    """One-hot encode an integer array.

    Args:
        depth (int): number of classes, i.e. the size of the new trailing axis.
        indices (numpy.ndarray): integer class indices of any shape.

    Returns:
        numpy.ndarray with shape ``indices.shape + (depth,)``; row ``i`` of the
        identity matrix is used as the encoding of class ``i``.
    """
    identity = np.eye(depth)
    flat_rows = identity[indices.reshape(-1)]
    target_shape = [*indices.shape, depth]
    return flat_rows.reshape(target_shape)

def correct_msa_restypes(msa, deletion_matrix=None, is_evogen=False):
    """Correct MSA restype to have the same order as residue_constants.

    Every entry of `msa` is used as an index into
    `MAP_HHBLITS_AATYPE_TO_OUR_AATYPE` and replaced by the value found there.

    Args:
        msa (numpy.ndarray): integer residue indices to remap.
        deletion_matrix (numpy.ndarray): only read when `is_evogen` is True; it
            is concatenated to the remapped MSA along the last axis.
        is_evogen (bool): whether to additionally build the concatenated input.

    Returns:
        The remapped `msa` when `is_evogen` is False, otherwise the tuple
        ``(msa, msa_input)`` where `msa_input` is the int32 concatenation of
        the remapped MSA and `deletion_matrix`.
    """
    new_order = np.array(MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=msa.dtype)
    msa = new_order[msa]
    if is_evogen:
        msa_input = np.concatenate((msa, deletion_matrix), axis=-1).astype(np.int32)
        return msa, msa_input
    return msa

def randomly_replace_msa_with_unknown(msa, aatype, replace_proportion):
    """Replace a proportion of the MSA with 'X'.

    Each entry is replaced independently with probability
    `replace_proportion`. Entries of `msa` equal to the gap index (21) are
    never replaced; `aatype` has no such exemption.

    Args:
        msa (numpy.ndarray): integer MSA.
        aatype (numpy.ndarray): integer residue types of the target sequence.
        replace_proportion (float): per-entry replacement probability.

    Returns:
        Tuple ``(msa, aatype)`` with the selected entries set to 20.
    """
    unknown_idx, gap_idx = 20, 21

    # The MSA draw must come first so the random stream is consumed in a fixed order.
    msa_hits = np.random.uniform(low=0, high=1, size=msa.shape) < replace_proportion
    msa_hits = np.logical_and(msa_hits, msa != gap_idx)
    new_msa = np.where(msa_hits, np.full_like(msa, unknown_idx), msa)

    aatype_hits = np.random.uniform(low=0, high=1, size=aatype.shape) < replace_proportion
    new_aatype = np.where(aatype_hits, np.full_like(aatype, unknown_idx), aatype)
    return new_msa, new_aatype

def fix_templates_aatype(template_aatype):
    """Fixes aatype encoding of templates.

    Args:
        template_aatype (numpy.ndarray): one-hot encoded residue types; the
            class axis is the last one.

    Returns:
        int32 numpy.ndarray of residue indices, obtained by taking the argmax
        over the last axis and remapping it through
        `MAP_HHBLITS_AATYPE_TO_OUR_AATYPE`.
    """
    # Map one-hot to indices.
    template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32)
    # Map hhsearch-aatype to our aatype.
    new_order = np.array(MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, np.int32)
    return new_order[template_aatype]

def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Compute pseudo beta features from atom positions.

    The CA atom is used for residues whose type equals ``restype_order['G']``
    and the CB atom for every other residue.

    Args:
        aatype (numpy.ndarray): integer residue types.
        all_atom_positions (numpy.ndarray): atom coordinates; the atom axis is
            second to last and the coordinate axis is last.
        all_atom_masks (numpy.ndarray or None): per-atom mask with the atom
            axis last, or None to skip the mask output.

    Returns:
        ``pseudo_beta`` alone when `all_atom_masks` is None, otherwise the
        tuple ``(pseudo_beta, pseudo_beta_mask)`` with a float32 mask.
    """
    glycine = np.equal(aatype, restype_order['G'])
    ca, cb = atom_order['CA'], atom_order['CB']

    # Repeat the glycine flag over the three coordinate components.
    glycine_xyz = np.repeat(glycine[..., None], 3, axis=-1)
    positions = np.where(glycine_xyz, all_atom_positions[..., ca, :], all_atom_positions[..., cb, :])

    if all_atom_masks is None:
        return positions

    mask = np.where(glycine, all_atom_masks[..., ca], all_atom_masks[..., cb])
    return positions, mask.astype(np.float32)

def make_atom14_masks(aatype):
    """Create atom14 <-> atom37 index maps and existence masks from aatype.

    One lookup row is built for every entry of `restypes`, plus one all-zero
    row appended for the unknown residue type; the rows are then gathered with
    `aatype`.

    Args:
        aatype (numpy.ndarray): integer residue types used to index the tables.

    Returns:
        List ``[atom14_atom_exists, residx_atom14_to_atom37,
        residx_atom37_to_atom14, atom37_atom_exists]``.
    """
    atom14_to_atom37_table = []
    atom37_to_atom14_table = []
    atom14_mask_table = []

    for letter in restypes:
        names14 = restype_name_to_atom14_names.get(restype_1to3.get(letter))
        slot_of = {atom: slot for slot, atom in enumerate(names14)}

        # Empty names mark unused atom14 slots; they map to index 0 and are masked out.
        atom14_to_atom37_table.append([atom_order[atom] if atom else 0 for atom in names14])
        atom37_to_atom14_table.append([slot_of.get(atom, 0) for atom in atom_types])
        atom14_mask_table.append([1. if atom else 0. for atom in names14])

    # Add dummy mapping for restype 'UNK'
    atom14_to_atom37_table.append([0] * 14)
    atom37_to_atom14_table.append([0] * 37)
    atom14_mask_table.append([0.] * 14)

    atom14_to_atom37_table = np.array(atom14_to_atom37_table, np.int32)
    atom37_to_atom14_table = np.array(atom37_to_atom14_table, np.int32)
    atom14_mask_table = np.array(atom14_mask_table, np.float32)

    residx_atom14_to_atom37 = atom14_to_atom37_table[aatype]
    atom14_atom_exists = atom14_mask_table[aatype]

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = atom37_to_atom14_table[aatype]

    # create the corresponding mask
    atom37_mask_table = np.zeros([21, 37], np.float32)
    for row, letter in enumerate(restypes):
        for atom in residue_atoms.get(restype_1to3.get(letter)):
            atom37_mask_table[row, atom_order[atom]] = 1

    atom37_atom_exists = atom37_mask_table[aatype]
    return [atom14_atom_exists, residx_atom14_to_atom37, residx_atom37_to_atom14, atom37_atom_exists]

def block_delete_msa_indices(msa, msa_fraction_per_block, randomize_num_blocks, num_blocks):
    """Sample MSA by deleting contiguous blocks.

    Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

    Args:
        msa (numpy.ndarray): MSA whose first axis indexes sequences; only its
            first dimension is read.
        msa_fraction_per_block (float): each deleted block spans
            ``floor(num_seq * msa_fraction_per_block)`` consecutive sequences.
        randomize_num_blocks (bool): if True the number of blocks is drawn
            uniformly from ``0..num_blocks``, otherwise exactly `num_blocks`
            blocks are used.
        num_blocks (int): (maximum) number of blocks to delete.

    Returns:
        Sorted list of Python ints with the row indices to keep. Index 0 is
        always kept.
    """
    num_seq = msa.shape[0]
    block_num_seq = np.floor(num_seq * msa_fraction_per_block).astype(np.int32)

    if randomize_num_blocks:
        nb = int(np.random.uniform(0, num_blocks + 1))
    else:
        nb = num_blocks

    del_block_starts = np.random.uniform(0, num_seq, nb).astype(np.int32)
    del_blocks = del_block_starts[:, None] + np.arange(block_num_seq, dtype=np.int32)
    # Blocks that run past the end are clamped onto the last sequence.
    del_blocks = np.clip(del_blocks, 0, num_seq - 1)
    del_indices = np.unique(np.reshape(del_blocks, (-1,)))

    # Make sure we keep the original sequence
    keep_indices = np.setdiff1d(np.arange(1, num_seq), del_indices)
    keep_indices = np.concatenate([[0], keep_indices], axis=0)
    return [int(x) for x in keep_indices]

def sample_msa(msa, max_seq):
    """Sample MSA randomly, remaining sequences are stored as `extra_*`.

    Row 0 always stays first; the remaining row indices are shuffled before
    the split so that the selection is actually random.

    Args:
        msa (numpy.ndarray): MSA whose first axis indexes sequences; only its
            first dimension is read.
        max_seq (int): maximum number of sequences to select.

    Returns:
        Tuple ``(is_sel, not_sel_seq, sel_seq)`` where `is_sel` is the number
        of sequences that were NOT selected (``num_seq - num_sel``),
        `not_sel_seq` holds their int32 indices and `sel_seq` holds the int32
        indices of the selected sequences.
    """
    num_seq = msa.shape[0]

    shuffled = list(range(1, num_seq))
    np.random.shuffle(shuffled)
    # Keep the first sequence in front regardless of the shuffle.
    index_order = np.array([0] + shuffled, np.int32)
    num_sel = min(max_seq, num_seq)

    sel_seq = index_order[:num_sel]
    not_sel_seq = index_order[num_sel:]
    is_sel = num_seq - num_sel
    return is_sel, not_sel_seq, sel_seq

def shape_list(x):
    """Get the list of dimensions of an array.

    Args:
        x: a numpy.ndarray or anything `numpy` can interpret as an array.

    Returns:
        List of ints, one per axis (empty for a scalar).
    """
    return list(np.shape(x))

def shaped_categorical(probability):
    """Draw one class index per distribution stored along the last axis.

    Args:
        probability (numpy.ndarray): array whose last axis holds the class
            probabilities of one categorical distribution.

    Returns:
        int32 numpy.ndarray with the shape of `probability` minus its last
        axis, holding the sampled class indices.
    """
    dims = shape_list(probability)
    num_classes = dims[-1]
    rows = np.reshape(probability, (-1, num_classes))
    classes = list(range(num_classes))
    # One independent draw per flattened row, in row order.
    draws = [np.random.choice(classes, p=row) for row in rows]
    return np.reshape(np.array(draws, np.int32), dims[:-1])

def make_masked_msa(msa, hhblits_profile, uniform_prob, profile_prob, same_prob, replace_fraction, residue_index=None,
                    msa_mask=None, is_evogen=False):
    """Create masked msa for BERT on raw MSA features.

    For every MSA position a replacement is sampled from a mixture of a
    uniform distribution over the first 20 classes, `hhblits_profile`, the
    original residue, and one extra appended class that receives the left-over
    probability ``1 - uniform_prob - profile_prob - same_prob``. The
    replacement is applied to a random `replace_fraction` of the positions.

    Args:
        msa (numpy.ndarray): integer MSA with values below 22.
        hhblits_profile (numpy.ndarray): per-position profile over 22 classes.
        uniform_prob (float): weight of the uniform component.
        profile_prob (float): weight of the profile component.
        same_prob (float): weight of keeping the original residue.
        replace_fraction (float): probability that a position is replaced.
        residue_index (array-like): only read when `is_evogen` is True.
        msa_mask (numpy.ndarray): only read when `is_evogen` is True.
        is_evogen (bool): whether to also build `additional_input`.

    Returns:
        ``(bert_mask, true_msa, msa)``, with ``additional_input`` appended as
        a fourth element when `is_evogen` is True. `bert_mask` is an int32
        indicator of the replaced positions, `true_msa` the untouched input
        and `msa` the masked MSA.
    """
    # Uniform over the first 20 classes; the last two classes get zero weight.
    random_aatype = np.array([0.05] * 20 + [0., 0.], dtype=np.float32)

    probability = uniform_prob * random_aatype + profile_prob * hhblits_profile + same_prob * one_hot(22, msa)

    # Append one extra class on the last axis carrying the remaining probability.
    pad_shapes = [[0, 0] for _ in range(len(probability.shape))]
    pad_shapes[-1][1] = 1
    mask_prob = 1. - profile_prob - same_prob - uniform_prob
    probability = np.pad(probability, pad_shapes, constant_values=(mask_prob,))

    masked_aatype = np.random.uniform(size=msa.shape, low=0, high=1) < replace_fraction

    bert_msa = shaped_categorical(probability)
    bert_msa = np.where(masked_aatype, bert_msa, msa)

    bert_mask = masked_aatype.astype(np.int32)
    true_msa = msa
    msa = bert_msa
    if is_evogen:
        # Each piece is a column vector taken from the first MSA row, so the
        # columns are joined side by side.
        # NOTE(review): the int32 cast mirrors the evogen branch of
        # correct_msa_restypes; confirm against the consumer of this feature.
        additional_input = np.concatenate((bert_msa[0][:, None], np.asarray(residue_index)[:, None],
                                           msa_mask[0][:, None],
                                           bert_mask[0][:, None]),
                                          axis=-1).astype(np.int32)
        return bert_mask, true_msa, msa, additional_input
    return bert_mask, true_msa, msa

def nearest_neighbor_clusters(msa_mask, msa, extra_msa_mask, extra_msa, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.

    Args:
        msa_mask (numpy.ndarray): mask of the sampled MSA, same shape as `msa`.
        msa (numpy.ndarray): sampled integer MSA with values below 23.
        extra_msa_mask (numpy.ndarray): mask of the extra MSA.
        extra_msa (numpy.ndarray): extra integer MSA with values below 23.
        gap_agreement_weight (float): weight given to agreeing on class 21.

    Returns:
        numpy.ndarray with, for every extra sequence, the index of the sampled
        sequence with the highest weighted agreement. An empty array is
        returned when `extra_msa_mask` contains no non-zero entry.
    """
    # Determine how much weight we assign to each agreement.  In theory, we could
    # use a full blosum matrix here, but right now let's just down-weight gap
    # agreement because it could be spurious.
    # Never put weight on agreeing on BERT mask
    weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0)

    # Make agreement score as weighted Hamming distance
    sample_one_hot = msa_mask[:, :, None] * one_hot(23, msa)
    num_seq, num_res, _ = sample_one_hot.shape

    if extra_msa_mask.any():
        extra_one_hot = extra_msa_mask[:, :, None] * one_hot(23, extra_msa)
        extra_num_seq, _, _ = extra_one_hot.shape

        agreement = np.matmul(
            np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
            np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T)
        # Assign each sequence in the extra sequences to the closest MSA sample
        extra_cluster_assignment = np.argmax(agreement, axis=1)
    else:
        # Nothing to assign when the extra MSA is fully masked out.
        extra_cluster_assignment = np.array([])
    return extra_cluster_assignment

def summarize_clusters(msa, msa_mask, extra_cluster_assignment, extra_msa_mask, extra_msa, extra_deletion_matrix,
                       deletion_matrix):
    """Produce profile and deletion_matrix_mean within each cluster.

    Cluster ``i`` consists of row ``i`` of the sampled MSA plus every extra
    sequence whose entry in `extra_cluster_assignment` equals ``i``.

    Args:
        msa (numpy.ndarray): sampled integer MSA with values below 23.
        msa_mask (numpy.ndarray): mask of the sampled MSA.
        extra_cluster_assignment (numpy.ndarray): cluster index of every extra
            sequence.
        extra_msa_mask (numpy.ndarray): mask of the extra MSA.
        extra_msa (numpy.ndarray): extra integer MSA with values below 23.
        extra_deletion_matrix (numpy.ndarray): deletion values of the extra MSA.
        deletion_matrix (numpy.ndarray): deletion values of the sampled MSA.

    Returns:
        Tuple ``(cluster_profile, cluster_deletion_mean)``: the per-cluster
        mean one-hot profile and the per-cluster mean deletion value.
    """
    num_seq = msa.shape[0]

    def csum(x):
        """Sum the rows of `x` assigned to each of the `num_seq` clusters."""
        result = []
        for i in range(num_seq):
            result.append(np.sum(x[np.where(extra_cluster_assignment == i)], axis=0))
        return np.array(result)

    mask = extra_msa_mask
    # The small constant avoids division by zero for fully masked positions.
    mask_counts = 1e-6 + msa_mask + csum(mask)  # Include center

    msa_sum = csum(mask[:, :, None] * one_hot(23, extra_msa))
    msa_sum += one_hot(23, msa)  # Original sequence
    cluster_profile = msa_sum / mask_counts[:, :, None]

    del msa_sum

    del_sum = csum(mask * extra_deletion_matrix)
    del_sum += deletion_matrix  # Original sequence
    cluster_deletion_mean = del_sum / mask_counts
    del del_sum

    return cluster_profile, cluster_deletion_mean

def crop_extra_msa(extra_msa, max_extra_msa):
    """MSA features are cropped so only `max_extra_msa` sequences are kept.

    The rows to keep are chosen at random (without replacement).

    Args:
        extra_msa (numpy.ndarray): extra MSA whose first axis indexes sequences.
        max_extra_msa (int): maximum number of sequences to keep.

    Returns:
        List of selected row indices, or None when `extra_msa` contains no
        non-zero entry.
    """
    if not extra_msa.any():
        return None
    num_seq = extra_msa.shape[0]
    num_sel = min(max_extra_msa, num_seq)
    shuffled = list(range(num_seq))
    np.random.shuffle(shuffled)
    return shuffled[:num_sel]

def make_msa_feat(between_segment_residues, aatype, msa, deletion_matrix, cluster_deletion_mean, cluster_profile,
                  extra_deletion_matrix=None):
    """Create and concatenate MSA features.

    Args:
        between_segment_residues (numpy.ndarray): per-residue break indicator.
        aatype (numpy.ndarray): integer residue types with values below 21.
        msa (numpy.ndarray): integer MSA with values below 23.
        deletion_matrix (numpy.ndarray): deletion values of the MSA.
        cluster_deletion_mean (numpy.ndarray): only read when
            `cluster_profile` is not None.
        cluster_profile (numpy.ndarray or None): per-cluster profile to append
            to the MSA features, or None to skip the cluster features.
        extra_deletion_matrix (numpy.ndarray or None): deletion values of the
            extra MSA, or None to skip the extra-deletion features.

    Returns:
        List ``[extra_has_deletion, extra_deletion_value, msa_feat,
        target_feat]``; the first two entries are None when
        `extra_deletion_matrix` is None.
    """
    # Whether there is a domain break. Always zero for chains, but keeping
    # for compatibility with domain datasets.
    has_break = np.clip(between_segment_residues.astype(np.float32), np.array(0), np.array(1))
    aatype_1hot = one_hot(21, aatype)

    target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot]

    msa_1hot = one_hot(23, msa)
    has_deletion = np.clip(deletion_matrix, np.array(0), np.array(1))
    # arctan squashes the unbounded deletion values into [0, 1) for non-negative input.
    deletion_value = np.arctan(deletion_matrix / 3.) * (2. / np.pi)

    msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)]

    if cluster_profile is not None:
        deletion_mean_value = (np.arctan(cluster_deletion_mean / 3.) * (2. / np.pi))
        msa_feat.extend([cluster_profile, np.expand_dims(deletion_mean_value, axis=-1)])
    extra_has_deletion = None
    extra_deletion_value = None
    if extra_deletion_matrix is not None:
        extra_has_deletion = np.clip(extra_deletion_matrix, np.array(0), np.array(1))
        extra_deletion_value = np.arctan(extra_deletion_matrix / 3.) * (2. / np.pi)

    msa_feat = np.concatenate(msa_feat, axis=-1)
    target_feat = np.concatenate(target_feat, axis=-1)
    return [extra_has_deletion, extra_deletion_value, msa_feat, target_feat]

def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32, random_recycle=False):
    """Draw uniform random numbers, optionally from a freshly seeded generator.

    Args:
        size: output shape passed to ``uniform``.
        seed_maker_t: seed for the private generator; only read when
            `random_recycle` is True.
        low: lower bound of the uniform range.
        high: upper bound of the uniform range.
        random_recycle (bool): if True, sample from a new
            ``np.random.RandomState(seed_maker_t)``; otherwise sample from the
            global NumPy random state.

    Returns:
        The value(s) drawn by ``uniform(size=size, low=low, high=high)``.
    """
    sampler = np.random.RandomState(seed_maker_t) if random_recycle else np.random
    return sampler.uniform(size=size, low=low, high=high)

def random_crop_to_size(seq_length, template_mask, crop_size, max_templates,
                        subsample_templates=False, seed=0, random_recycle=False):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Only the crop parameters are computed here; no feature is sliced.

    Args:
        seq_length: number of residues.
        template_mask (numpy.ndarray or None): its first dimension gives the
            number of templates; None means there are no templates.
        crop_size (int): maximum number of residues to keep.
        max_templates (int): maximum number of templates to keep.
        subsample_templates (bool): if True the template crop start is drawn
            at random, otherwise it is 0.
        seed: seed forwarded to `make_random_seed`.
        random_recycle (bool): forwarded to `make_random_seed`.

    Returns:
        List ``[num_res_crop_size, num_templates_crop_size_int,
        num_res_crop_start, num_res_crop_size_int, templates_crop_start,
        templates_select_indices]``.
    """
    seq_length_int = int(seq_length)
    if template_mask is not None:
        num_templates = np.array(template_mask.shape[0], np.int32)
    else:
        num_templates = np.array(0, np.int32)
    num_res_crop_size = np.minimum(seq_length, crop_size)
    num_res_crop_size_int = int(num_res_crop_size)

    # Ensures that the cropping of residues and templates happens in the same way
    # across ensembling iterations.
    # Do not use for randomness that should vary in ensembling.

    if subsample_templates:
        templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, high=num_templates + 1,
                                                    random_recycle=random_recycle))
    else:
        templates_crop_start = 0

    num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates)
    num_templates_crop_size_int = int(num_templates_crop_size)

    num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0,
                                              high=seq_length_int - num_res_crop_size_int + 1,
                                              random_recycle=random_recycle))

    # A random permutation of the template indices, obtained by sorting random keys.
    templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed,
                                                           random_recycle=random_recycle))
    return [num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int,
            templates_crop_start, templates_select_indices]

def atom37_to_torsion_angles(
        aatype: np.ndarray,
        all_atom_pos: np.ndarray,
        all_atom_mask: np.ndarray,
        alt_torsions=False,
):
    r"""
    This function calculates the seven torsion angles of each residue and encodes them in sine and cosine.
    The order of the seven torsion angles is [pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4].
    Here, pre_omega is the torsion between a given amino acid and the previous amino acid.

    Only the first batch element is returned, so the function is meant to be called with a batch size of 1.

    Args:
        aatype (numpy.array): Amino acid type with shape :math:`(batch\_size, N_{res})`.
            Values above 20 are clipped to 20.
        all_atom_pos (numpy.array): Atom37 representation of all atomic coordinates with
            shape :math:`(batch\_size, N_{res}, 37, 3)`.
        all_atom_mask (numpy.array): Atom37 representation of the mask on all atomic coordinates with
            shape :math:`(batch\_size, N_{res}, 37)`.
        alt_torsions (bool): If True, every torsion whose mask is 0 is replaced by the constant
            pair (1, 0) in both returned angle arrays. Default: False.

    Returns:
        Dict containing

        - torsion_angles_sin_cos (numpy.array), with shape :math:`(N_{res}, 7, 2)` where
          the final dimension denotes sin and cos respectively.
        - alt_torsion_angles_sin_cos (numpy.array), same as 'torsion_angles_sin_cos', but
          with the sign flipped for all chi angles marked in `chi_pi_periodic`.
        - torsion_angles_mask (numpy.array), with shape :math:`(N_{res}, 7)`, mask for which
          torsion angles are present.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.data.data_transform import atom37_to_torsion_angles
        >>> n_res = 16
        >>> bs = 1
        >>> aatype = np.random.randint(0, 21, (bs, n_res)).astype(np.int32)
        >>> all_atom_pos = np.random.randn(bs, n_res, 37, 3).astype(np.float32)
        >>> all_atom_mask = np.ones((bs, n_res, 37)).astype(np.float32)
        >>> angle_label_feature = atom37_to_torsion_angles(aatype, all_atom_pos, all_atom_mask)
        >>> print(angle_label_feature.keys())
        dict_keys(['torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask'])
    """
    true_aatype = np.minimum(aatype, 20)

    # get the number residue
    num_batch, num_res = true_aatype.shape

    # Shift positions and masks by one residue so that index i holds residue i-1
    # (zero padding for the first residue).
    paddings = np.zeros([num_batch, 1, 37, 3], np.float32)
    padding_atom_pos = np.concatenate([paddings, all_atom_pos[:, :-1, :, :]], axis=1)
    paddings = np.zeros([num_batch, 1, 37], np.float32)
    padding_atom_mask = np.concatenate([paddings, all_atom_mask[:, :-1, :]], axis=1)

    # compute padding atom position for omega, phi and psi
    omega_atom_pos_padding = np.concatenate(
        [padding_atom_pos[..., 1:3, :],
         all_atom_pos[..., 0:2, :]], axis=-2)
    phi_atom_pos_padding = np.concatenate(
        [padding_atom_pos[..., 2:3, :],
         all_atom_pos[..., 0:3, :]], axis=-2)
    psi_atom_pos_padding = np.concatenate(
        [all_atom_pos[..., 0:3, :],
         all_atom_pos[..., 4:5, :]], axis=-2)

    # compute padding atom position mask for omega, phi and psi
    omega_mask_padding = (np.prod(padding_atom_mask[..., 1:3], axis=-1) *
                          np.prod(all_atom_mask[..., 0:2], axis=-1))
    phi_mask_padding = (padding_atom_mask[..., 2] * np.prod(all_atom_mask[..., 0:3], axis=-1))
    psi_mask_padding = (np.prod(all_atom_mask[..., 0:3], axis=-1) * all_atom_mask[..., 4])

    # Gather the four atoms that define each chi angle of every residue.
    chi_atom_pos_indices = get_chi_atom_pos_indices()
    atom_pos_indices = np_gather_ops(chi_atom_pos_indices, true_aatype, 0, 0)
    chi_atom_pos = np_gather_ops(all_atom_pos, atom_pos_indices, -2, 2)

    # Extra all-zero row for the unknown residue type.
    angles_mask = list(chi_angles_mask)
    angles_mask.append([0.0, 0.0, 0.0, 0.0])
    angles_mask = np.array(angles_mask)

    chis_mask = np_gather_ops(angles_mask, true_aatype, 0, 0)

    # A chi angle only exists when all four of its atoms are present.
    chi_angle_atoms_mask = np_gather_ops(all_atom_mask, atom_pos_indices, -1, 2)
    chi_angle_atoms_mask = np.prod(chi_angle_atoms_mask, axis=-1)
    chis_mask = chis_mask * chi_angle_atoms_mask.astype(np.float32)

    torsions_atom_pos_padding = np.concatenate(
        [omega_atom_pos_padding[:, :, None, :, :],
         phi_atom_pos_padding[:, :, None, :, :],
         psi_atom_pos_padding[:, :, None, :, :],
         chi_atom_pos], axis=2)

    torsion_angles_mask_padding = np.concatenate(
        [omega_mask_padding[:, :, None],
         phi_mask_padding[:, :, None],
         psi_mask_padding[:, :, None],
         chis_mask], axis=2)

    # Build a frame from the first three atoms and express the fourth atom in it.
    torsion_frames = geometry.rigids_from_3_points(
        point_on_neg_x_axis=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 1, :]),
        origin=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 2, :]),
        point_on_xy_plane=geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 0, :]))

    inv_torsion_frames = geometry.invert_rigids(torsion_frames)
    vecs = geometry.vecs_from_tensor(torsions_atom_pos_padding[:, :, :, 3, :])

    forth_atom_rel_pos = geometry.rigids_mul_vecs(inv_torsion_frames, vecs)

    torsion_angles_sin_cos = np.stack(
        [forth_atom_rel_pos[2], forth_atom_rel_pos[1]], axis=-1)
    # Normalize to unit length; the epsilon guards against a zero vector.
    torsion_angles_sin_cos /= np.sqrt(
        np.sum(np.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)

    # Flip the sign of the third torsion (psi).
    torsion_angles_sin_cos *= np.array(
        [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]

    chi_is_ambiguous = np_gather_ops(
        np.array(chi_pi_periodic), true_aatype)
    # +1 for the three backbone torsions, -1 for ambiguous chi angles.
    mirror_torsion_angles = np.concatenate(
        [np.ones([num_batch, num_res, 3]),
         1.0 - 2.0 * chi_is_ambiguous], axis=-1)
    alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

    if alt_torsions:
        # Replace masked-out torsions by the constant (1, 0) pair.
        fix_torsions = np.stack([np.ones(torsion_angles_sin_cos.shape[:-1]),
                                 np.zeros(torsion_angles_sin_cos.shape[:-1])], axis=-1)
        present = torsion_angles_mask_padding[..., None]
        torsion_angles_sin_cos = torsion_angles_sin_cos * present + fix_torsions * (1 - present)
        alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * present + fix_torsions * (1 - present)

    return {
        'torsion_angles_sin_cos': torsion_angles_sin_cos[0],  # (N, 7, 2)
        'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos[0],  # (N, 7, 2)
        'torsion_angles_mask': torsion_angles_mask_padding[0]  # (N, 7)
    }
def atom37_to_frames(
        aatype,
        all_atom_positions,
        all_atom_mask,
        is_affine=False
):
    r"""
    Computes the frames of up to 8 rigid groups for each residue, shape is :math:`[N_{res}, 8, 12]`.
    Group 0 is the backbone frame, group 3 is the psi frame and groups 4 to 7 are the
    chi1 to chi4 frames; groups 1 and 2 are never marked as existing by this function.
    Of the 12 values, the first 9 are the components of the rotation matrix and the last 3 are
    the components of the translation.

    Args:
        aatype(numpy.array): Amino acid sequence, :math:`[N_{res}]` .
        all_atom_positions(numpy.array): The coordinates of all atoms, presented as atom37, :math:`[N_{res}, 37,3]`.
        all_atom_mask(numpy.array): Mask of all atomic coordinates, :math:`[N_{res},37]`.
        is_affine(bool): Whether to additionally return 'backbone_affine_tensor'. Default: False.

    Returns:
        Dictionary, the specific content is as follows.

        - **rigidgroups_gt_frames** (numpy.array) - The frames of the 8 rigid body groups for each residue,
          :math:`[N_{res}, 8, 12]`.
        - **rigidgroups_gt_exists** (numpy.array) - The mask of rigidgroups_gt_frames denoting whether all
          atoms defining the rigid body group are present in `all_atom_mask`, :math:`[N_{res}, 8]`.
        - **rigidgroups_group_exists** (numpy.array) - Mask denoting whether given group is in principle
          present for given amino acid type, :math:`[N_{res}, 8]` .
        - **rigidgroups_group_is_ambiguous** (numpy.array) - Mask denoting the groups affected by the
          atom renaming ambiguity listed in `residue_atom_renaming_swaps`, :math:`[N_{res}, 8]` .
        - **rigidgroups_alt_gt_frames** (numpy.array) - 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array :math:`[N_{res}, 8, 12]` .
        - **backbone_affine_tensor** (numpy.array) - Only present when `is_affine` is True. The backbone
          frame (group 0) of each residue, :math:`[N_{res}, 7]` ; in the last dimension the first 4
          elements are the quaternion of the rotation and the last 3 elements are the translation.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.data import atom37_to_frames
        >>> aatype = np.ones(193,dtype=np.int32)
        >>> all_atom_positions = np.ones((193,37,3),dtype=np.float32)
        >>> all_atom_mask = np.ones((193,37),dtype=np.int32)
        >>> result = atom37_to_frames(aatype,all_atom_positions,all_atom_mask)
        >>> for key in result.keys():
        >>>     print(key,result[key].shape)
        rigidgroups_gt_frames (193, 8, 12)
        rigidgroups_gt_exists (193, 8)
        rigidgroups_group_exists (193, 8)
        rigidgroups_group_is_ambiguous (193, 8)
        rigidgroups_alt_gt_frames (193, 8, 12)
    """
    aatype_shape = aatype.shape

    # Work on a flat residue axis; the original layout is restored at the end.
    flat_aatype = np.reshape(aatype, [-1])
    all_atom_positions = np.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = np.reshape(all_atom_mask, [-1, 37])

    # Names of the three atoms that define each group, per residue type.
    rigid_group_names_res = np.full([21, 8, 3], '', dtype=object)

    # group 0: backbone frame
    rigid_group_names_res[:, 0, :] = ['C', 'CA', 'N']

    # group 3: 'psi'
    rigid_group_names_res[:, 3, :] = ['CA', 'C', 'O']

    # group 4,5,6,7: 'chi1,2,3,4'
    for restype, letter in enumerate(restypes):
        restype_name = restype_1to3[letter]
        for chi_idx in range(4):
            if chi_angles_mask[restype][chi_idx]:
                atom_names = chi_angles_atoms[restype_name][chi_idx]
                rigid_group_names_res[restype, chi_idx + 4, :] = atom_names[1:]

    # create rigid group mask
    rigid_group_mask_res = np.zeros([21, 8], dtype=np.float32)
    rigid_group_mask_res[:, 0] = 1
    rigid_group_mask_res[:, 3] = 1
    rigid_group_mask_res[:20, 4:] = chi_angles_mask

    # Translate atom names to atom37 indices; the empty name maps to index 0.
    lookup_table = atom_order.copy()
    lookup_table[''] = 0
    rigid_group_atom37_idx_restype = np.vectorize(lambda x: lookup_table[x])(
        rigid_group_names_res)

    rigid_group_atom37_idx_residx = np_gather_ops(
        rigid_group_atom37_idx_restype, flat_aatype)

    base_atom_pos = np_gather_ops(
        all_atom_positions,
        rigid_group_atom37_idx_residx,
        batch_dims=1)

    gt_frames = geometry.rigids_from_3_points(
        point_on_neg_x_axis=geometry.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=geometry.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=geometry.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # get the group mask
    group_masks = np_gather_ops(rigid_group_mask_res, flat_aatype)

    # get the atom mask
    gt_atoms_exists = np_gather_ops(
        all_atom_mask.astype(np.float32),
        rigid_group_atom37_idx_residx,
        batch_dims=1)
    gt_masks = np.min(gt_atoms_exists, axis=-1) * group_masks

    # Apply an extra rotation diag(-1, 1, -1) to group 0 only.
    rotations = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rotations[0, 0, 0] = -1
    rotations[0, 2, 2] = -1
    gt_frames = geometry.rigids_mul_rots(gt_frames, geometry.rots_from_tensor(rotations, use_numpy=True))

    rigid_group_is_ambiguous_res = np.zeros([21, 8], dtype=np.float32)
    rigid_group_rotations_res = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])

    for restype_name in residue_atom_renaming_swaps:
        restype = restype_order[restype_3to1[restype_name]]
        # The last existing chi group of the residue is the ambiguous one.
        chi_idx = int(sum(chi_angles_mask[restype]) - 1)
        rigid_group_is_ambiguous_res[restype, chi_idx + 4] = 1
        rigid_group_rotations_res[restype, chi_idx + 4, 1, 1] = -1
        rigid_group_rotations_res[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    rigid_group_is_ambiguous_res_index = np_gather_ops(
        rigid_group_is_ambiguous_res, flat_aatype)
    rigid_group_ambiguity_rotation_res_index = np_gather_ops(
        rigid_group_rotations_res, flat_aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = geometry.rigids_mul_rots(
        gt_frames, geometry.rots_from_tensor(rigid_group_ambiguity_rotation_res_index, use_numpy=True))

    gt_frames_flat12 = np.stack(list(gt_frames[0]) + list(gt_frames[1]), axis=-1)
    alt_gt_frames_flat12 = np.stack(list(alt_gt_frames[0]) + list(alt_gt_frames[1]), axis=-1)

    # reshape back to original residue layout
    gt_frames_flat12 = np.reshape(gt_frames_flat12, aatype_shape + (8, 12))
    gt_masks = np.reshape(gt_masks, aatype_shape + (8,))
    group_masks = np.reshape(group_masks, aatype_shape + (8,))
    rigid_group_is_ambiguous_res_index = np.reshape(rigid_group_is_ambiguous_res_index, aatype_shape + (8,))
    alt_gt_frames_flat12 = np.reshape(alt_gt_frames_flat12, aatype_shape + (8, 12,))

    frames = {
        'rigidgroups_gt_frames': gt_frames_flat12,  # shape (..., 8, 12)
        'rigidgroups_gt_exists': gt_masks,  # shape (..., 8)
        'rigidgroups_group_exists': group_masks,  # shape (..., 8)
        'rigidgroups_group_is_ambiguous': rigid_group_is_ambiguous_res_index,  # shape (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # shape (..., 8, 12)
    }
    if not is_affine:
        return frames

    rotation = [[gt_frames[0][0], gt_frames[0][1], gt_frames[0][2]],
                [gt_frames[0][3], gt_frames[0][4], gt_frames[0][5]],
                [gt_frames[0][6], gt_frames[0][7], gt_frames[0][8]]]
    translation = [gt_frames[1][0], gt_frames[1][1], gt_frames[1][2]]
    # Keep group 0 (the backbone frame) only; computed on the flattened residue axis.
    frames['backbone_affine_tensor'] = to_tensor(rotation, translation)[:, 0, :]  # shape (..., 7)
    return frames
def get_chi_atom_pos_indices():
    """Get the atom indices for computing chi angles for all residue types.

    Returns:
        numpy.ndarray with one entry per residue type in `restypes` plus one
        trailing entry for the unknown residue; each entry holds 4 chi angles
        with 4 atom indices each. Undefined chi angles are filled with zeros.
    """
    chi_atom_pos_indices = []
    for residue_name in restypes:
        residue_name = restype_1to3[residue_name]
        residue_chi_angles = chi_angles_atoms[residue_name]
        atom_pos_indices = [[atom_order[atom] for atom in chi_angle] for chi_angle in residue_chi_angles]
        for _ in range(4 - len(atom_pos_indices)):
            atom_pos_indices.append([0, 0, 0, 0])  # For chi angles not defined on the AA.
        chi_atom_pos_indices.append(atom_pos_indices)

    chi_atom_pos_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return np.array(chi_atom_pos_indices)

def gather(params, indices, axis=0):
    """Gather operation: take `indices` from `params` along `axis`."""
    return np.take(params, indices, axis=axis)

def np_gather_ops(params, indices, axis=0, batch_dims=0):
    """Numpy gather operation with optional leading batch dimensions.

    Args:
        params (numpy.ndarray): array to gather from.
        indices (numpy.ndarray): integer indices.
        axis (int): axis to gather along. NOTE: it is ignored when
            `batch_dims` is 0, where the gather always runs along axis 0.
        batch_dims (int): 0 gathers directly; 1 gathers every leading element
            of `params` with the matching element of `indices`; any other
            value gathers over the second axis and only uses the first element
            of the leading axis.

    Returns:
        The gathered numpy.ndarray. For `batch_dims` other than 0 and 1 the
        result gets a leading axis of size 1.
    """
    if batch_dims == 0:
        return gather(params, indices)
    if batch_dims == 1:
        # Shift the axis once to account for the removed batch dimension;
        # doing it inside the loop would shift it again on every iteration.
        inner_axis = axis - batch_dims if axis - batch_dims > 0 else 0
        return np.stack([gather(p, i, axis=inner_axis) for p, i in zip(params, indices)])
    res = np.stack([gather(p, i, axis=axis) for p, i in zip(params[0], indices[0])])
    return res.reshape((1,) + res.shape)

def rot_to_quat(rot, unstack_inputs=False):
    """Transfer the rotation matrix to quaternion.

    Args:
        rot: 3x3 nested list of rotation components, or an array whose last
            two axes are the 3x3 matrix when `unstack_inputs` is True.
        unstack_inputs (bool): whether `rot` must first be split into its
            nine components.

    Returns:
        numpy.ndarray whose last axis of size 4 is the eigenvector belonging
        to the largest eigenvalue of the symmetric 4x4 matrix built from the
        rotation; its sign is not normalized.
    """
    if unstack_inputs:
        rot = [np.moveaxis(x, -1, 0) for x in np.moveaxis(rot, -2, 0)]
    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot

    k = [[xx + yy + zz, zy - yz, xz - zx, yx - xy],
         [zy - yz, xx - yy - zz, xy + yx, xz + zx],
         [xz - zx, xy + yx, yy - xx - zz, yz + zy],
         [yx - xy, xz + zx, yz + zy, zz - xx - yy]]

    k = (1. / 3.) * np.stack([np.stack(x, axis=-1) for x in k], axis=-2)
    # eigh returns eigenvalues in ascending order, so the last eigenvector
    # belongs to the largest eigenvalue.
    _, qs = np.linalg.eigh(k)
    return qs[..., -1]

def to_tensor(rotation, translation):
    """Get affine based on rotation and translation.

    Returns:
        numpy.ndarray whose last axis has 7 entries: the 4 quaternion
        components followed by the 3 translation components.
    """
    quaternion = rot_to_quat(rotation)
    return np.concatenate(
        [quaternion] + [np.expand_dims(x, axis=-1) for x in translation],
        axis=-1)

def convert_monomer_features(chain_id, aatype, template_aatype):
    """Reshapes and modifies monomer features for multimer models.

    Returns:
        Tuple ``(auth_chain_id, monomer_aatype, monomer_template_aatype)``:
        `chain_id` as an object array, the argmax of `aatype` over its last
        axis, and the argmax of `template_aatype` remapped through
        `MAP_HHBLITS_AATYPE_TO_OUR_AATYPE`.
    """
    auth_chain_id = np.asarray(chain_id, dtype=np.object_)
    new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    monomer_aatype = np.argmax(aatype, axis=-1).astype(np.int32)
    monomer_template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32)
    monomer_template_aatype = np.take(new_order_list, monomer_template_aatype.astype(np.int32), axis=0)
    return auth_chain_id, monomer_aatype, monomer_template_aatype

def convert_unnecessary_leading_dim_feats(sequence, domain_name, num_alignments, seq_length):
    """Get first dimension data of unnecessary features.

    Returns:
        Tuple with the first element of each of the four inputs, each
        converted to an array of the dtype of its input.
    """
    monomer_sequence = np.asarray(sequence[0], dtype=sequence.dtype)
    monomer_domain_name = np.asarray(domain_name[0], dtype=domain_name.dtype)
    monomer_num_alignments = np.asarray(num_alignments[0], dtype=num_alignments.dtype)
    monomer_seq_length = np.asarray(seq_length[0], dtype=seq_length.dtype)
    return monomer_sequence, monomer_domain_name, monomer_num_alignments, monomer_seq_length

def process_unmerged_features(deletion_matrix_int, deletion_matrix_int_all_seq, aatype, entity_id, num_chains):
    """Postprocessing stage for per-chain features before merging.

    Returns:
        Tuple ``(deletion_matrix, deletion_matrix_all_seq, deletion_mean,
        all_atom_mask, all_atom_positions, assembly_num_chains, entity_mask)``.
        `all_atom_positions` is all zeros.
    """
    # Convert deletion matrices to float.
    deletion_matrix = np.asarray(deletion_matrix_int, dtype=np.float32)
    deletion_matrix_all_seq = np.asarray(deletion_matrix_int_all_seq, dtype=np.float32)

    all_atom_mask = STANDARD_ATOM_MASK[aatype]
    all_atom_positions = np.zeros(list(all_atom_mask.shape) + [3])
    deletion_mean = np.mean(deletion_matrix, axis=0)

    # Add assembly_num_chains.
    assembly_num_chains = np.asarray(num_chains)
    entity_mask = (entity_id != 0).astype(np.int32)
    return (deletion_matrix, deletion_matrix_all_seq, deletion_mean, all_atom_mask, all_atom_positions,
            assembly_num_chains, entity_mask)

def get_crop_size(num_alignments_all_seq, msa_all_seq, msa_crop_size, msa_size):
    """Get maximum msa crop size.

    Args:
        num_alignments_all_seq: num_alignments for all sequence, which record the total number of msa
        msa_all_seq: un-paired sequences for all msa.
        msa_crop_size: The total number of sequences to crop from the MSA.
        msa_size: number of msa

    Returns:
        msa_crop_size: msa sized to be cropped
        msa_crop_size_all_seq: msa_crop_size for features with "_all_seq"
    """
    msa_size_all_seq = num_alignments_all_seq
    msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

    # We reduce the number of un-paired sequences, by the number of times a
    # sequence from this chain's MSA is included in the paired MSA.  This keeps
    # the MSA size for each chain roughly constant.
    msa_all_seq = msa_all_seq[:msa_crop_size_all_seq, :]
    num_non_gapped_pairs = np.sum(np.any(msa_all_seq != restypes_with_x_and_gap.index('-'), axis=1))
    num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, msa_crop_size_all_seq)

    # Restrict the unpaired crop size so that paired+unpaired sequences do not
    # exceed msa_seqs_per_chain for each chain.
    max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
    msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    return msa_crop_size, msa_crop_size_all_seq

def make_seq_mask(entity_id):
    """Seq mask info, 1.0 for entity_id > 0, 0.0 for entity_id <= 0."""
    return (entity_id > 0).astype(np.float32)

def make_msa_mask(msa, entity_id):
    """Mask features are all ones, but will later be zero-padded.

    Returns:
        float32 array with the shape of `msa`; columns whose `entity_id` is
        not positive are zero.
    """
    msa_mask = np.ones_like(msa, dtype=np.float32)
    seq_mask = (entity_id > 0).astype(np.float32)
    msa_mask *= seq_mask[None]
    return msa_mask

def add_padding(feature_name, feature):
    """Get padding data with specified shapes of feature.

    Returns:
        Array of shape ``(1, feature.shape[1])`` filled with the pad value
        registered for `feature_name` in `MSA_PAD_VALUES`.
    """
    num_res = feature.shape[1]
    return MSA_PAD_VALUES.get(feature_name) * np.ones([1, num_res], feature.dtype)

def generate_random_sample(cfg, model_config):
    """Generate the random latent input and the context/target masks.

    NOTE: this re-seeds the global NumPy random state with 0 on every call.

    Args:
        cfg: config providing ``eval.crop_size``.
        model_config: config providing ``model.latent.num_noise``,
            ``model.latent.latent_dim`` and the ``train`` fields
            ``context_true_prob``, ``keep_prob``, ``available_msa_fraction``
            and ``max_msa_clusters``.

    Returns:
        Tuple ``(evogen_random_data, evogen_context_mask)``: float32 normal
        noise of shape ``(num_noise, max_msa_clusters, crop_size, latent_dim)``
        and an int32 array of shape ``(max_msa_clusters, 2)`` holding the
        context mask and the target mask.
    """
    np.random.seed(0)
    num_noise = model_config.model.latent.num_noise
    latent_dim = model_config.model.latent.latent_dim
    max_msa_clusters = model_config.train.max_msa_clusters

    context_true_prob = np.absolute(model_config.train.context_true_prob)
    keep_prob = np.absolute(model_config.train.keep_prob)

    available_msa = int(model_config.train.available_msa_fraction * max_msa_clusters)
    available_msa = min(available_msa, max_msa_clusters)

    evogen_random_data = np.random.normal(
        size=(num_noise, max_msa_clusters, cfg.eval.crop_size, latent_dim)).astype(np.float32)

    # (Nseq,): rows at or beyond `available_msa` are never used as context.
    z1 = np.random.random(max_msa_clusters)
    context_mask = (z1 < context_true_prob).astype(np.int32)
    context_mask[available_msa:] = 0

    # (Nseq,):
    z2 = np.random.random(max_msa_clusters)
    target_mask = (z2 < keep_prob).astype(np.int32)

    # The first sequence is always part of both masks.
    context_mask[0] = 1
    target_mask[0] = 1

    evogen_context_mask = np.stack((context_mask, target_mask), -1)
    return evogen_random_data, evogen_context_mask