联邦学习 | 如何使得模型训练精度不受影响的同时,数据隐私也不泄露呢?
联邦学习 | 如何使得模型训练精度不受影响的同时,数据隐私也不泄露呢?

作者:AI安全Mr.Jin |来源:知乎
在上一节我们讲到了可以使用差分隐私训练算法来保护数据隐私,并且其主要的原理是对敏感数据(模型的权重或更新的梯度)进行加噪来隐藏真实的值。确定加噪的多少也是一个技术活,如果加噪太多,隐私保护效果会很好,但是模型训练可能不收敛;如果加噪太少,隐私保护效果可能不太好。那么有没有办法使得模型训练精度不受影响的同时,数据隐私也不泄露呢?

还真有,那就是这次讲的基于密码学的“安全聚合训练”[1][2]。安全聚合训练的总流程比较复杂,我们接下来从最简单的场景逐步分析(想直接看完整流程的同学可以跳到第6小节哦!)。
1,pairwise encryption
假设每个参与训练的客户端 u 的梯度值是 xu ,中心服务器采用平均聚合的方式求梯度值: x¯=(∑u∈Uxu)/|U| ( |U| 为客户端的数量),那么中心服务器实际上只需要知道 ∑u∈Uxu 的值,而不一定需要知道每个 xu 的值。那么怎么实现呢?只需要这样做:
① 所有客户端两两之间建立安全信道。对于每个客户端 u ,它为其它每个 v∈U 生成一个和 xu 同维度的噪声向量 su,v 。
② 所有客户端两两之间通过安全信道交换它们为对方生成的噪声向量。比如 u 把 su,v 发给 v , v 把 sv,u 发给 u 。并且 u 计算 pu,v=su,v−sv,u 作为它和 v 之间协商的秘密扰动, v 计算 pv,u=sv,u−su,v 作为它和 u 之间的秘密扰动。可以发现, pu,v=−pv,u(pu,u=0) 。
③ 每一个客户端 u 把它和其它客户端之间协商的秘密扰动加到 xu 上:
yu=xu+∑v∈Upu,v 。
④ 每个客户端把 yu 发送给中心服务器,而不是发送 xu 。
⑤ 中心服务器收到所有客户端发送的 yu 之后,进行求和: x¯=∑u∈Uyu 。下面证明 ∑u∈Uyu=∑u∈Uxu :
∑u∈Uyu=∑u∈U(xu+∑v∈Upu,v)=∑u∈Uxu+∑u∈U∑v∈Upu,v
注意到 pu,v=−pv,u ,所以 ∑u∈U∑v∈Upu,v=0 ,从而上式为: ∑u∈Uyu=∑u∈Uxu 。也就是说,尽管每个客户端发送给中心服务器的值是经过加扰的,但是中心服务器求和的结果和真实的 xu 之和是完全相等的。
是不是很奇妙?中心服务器不需要知道每个 xu 的真实值,却可以把 ∑u∈Uxu 的值求出来,我们把这种方式叫做“pairwise encryption”。
2,有客户端掉线怎么办?
在真实的端云联邦学习场景中,客户端是手机,手机的状态是不稳定的。如果某台手机完成了上小节中的①②③之后掉线了,没有完成第④步,那么各客户端之间的秘密扰动就不能完成抵消了。举个例子,假如客户端 k 掉线了,那么中心服务器的聚合结果为:
∑u∈U,u≠kyu=∑u∈U,u≠k(xu+∑v∈Upu,v)=∑u∈U,u≠kxu+∑u∈U,u≠k∑v∈Upu,v=∑u∈U,u≠kxu+∑u∈U,u≠k(∑v∈U,v≠kpu,v+pu,k)=∑u∈U,u≠kxu+∑u∈U,u≠k∑v∈U,v≠kpu,v+∑u∈U,u≠kpu,k=∑u∈U,u≠kxu+∑u∈U,u≠kpu,k
可以看到, ∑u∈U,u≠kpu,k 这部分秘密扰动没有抵消掉,会导致计算错误。这部分秘密扰动是其它客户端和 k 之间 协商好的,它们把秘密扰动加到梯度向量上、并且发送给中心服务器之后, k 出了意外,没有把 yk 上传给中心服务器,从而导致和 k 有关的秘密扰动都抵消不了。
那么怎么办呢?可以使用密钥重构机制来恢复掉线客户端的秘密扰动(粗体部分是相比于第一节新增的步骤):
① 所有客户端两两之间建立安全信道。对于每个客户端 u ,它为其它每个 v∈U 生成一个和 xu 同维度的噪声向量 su,v 。
② 每个客户端生成一对非对称加密的公私钥(公钥加密、私钥解密),然后所有客户端通过安全信道互相发送公钥。这样的话,后续客户端之间就通过中心服务器进行信息交换,而不是通过安全信道。当然,客户端 u 通过中心服务器把数据发送给客户端 v 之前,会使用 v 发给它的公钥对数据进行加密,这样的话中心服务器就不知道真实数据了。这样做的好处是中心服务器知道每一轮有哪些客户端参与了训练。
③ 所有客户端两两之间交换(通过原有的安全信道直接交换或者通过中心服务器间接交换都可以)它们为对方生成的噪声向量。比如 u 把 su,v 发给 v , v 把 sv,u 发给 u 。并且 u 计算 pu,v=su,v−sv,u 作为它和 v 之间协商的秘密扰动, v 计算 pv,u=sv,u−su,v 作为它和 u 之间的秘密扰动。可以发现, pu,v=−pv,u(pu,u=0) 。
④ 假设 n=|U| ,每一个客户端 u 使用 (t,n) 秘密分享协议(比如Shamir's Secret Sharing协议,其功能是把一个明文数据拆成 n 份碎片,至少需要碎片中的 t 份才能还原原来的明文数据)把每一个 pu,v(v∈U) 分成 n 个秘密碎片 pu,v,i(i=1,2,...,n) ,然后使用第②步收到的 n 个客户端的公钥 pubkeyi(i=1,2,...,n) 分别加密这 n 个碎片,然后通过中心服务器把 enc(pu,v,i,pubkeyi) 发送给客户端 i 。
⑤ 假设中心服务器收到了集合为 U1 的客户端发送过来的秘密碎片消息( U1⊆U ,可能有客户端没有完成第④步就掉线了),然后把所有碎片按照转发目标进行分类整理,整理后再一起发送给对应的目标客户端。
⑥ 每个客户端 u 收到了中心服务器转发过来的秘密碎片,并且知道了在线的客户端集合是 U1 ,那 pu,v(u∈U1/U2,v∈U2) 么 u 把它和 U1 中每个客户端之间协商的秘密扰动加到 xu **上:**yu=xu+∑v∈U1pu,v 。
⑦ 每个客户端把 yu 发送给中心服务器,而不是发送 xu 。
⑧ 中心服务器收到了集合为 U2 的客户端发送过来的 yu ( U2⊆U1 ,可能有客户端在⑥⑦掉线了)。如果 U2 是 U1 的真子集,中心服务器直接进行聚合: x¯=∑u∈U2yu=∑u∈U2(xu+∑v∈U1pu,v) ,则会出现本小节开头分析的秘密扰动不能抵消的情况。那么怎么办呢?中心服务器需要做如下操作:
⑨ 中心服务器确定掉线的客户端名单 U1/U2**(这里应该是'\'个符号,表述除去,不知道知乎的公式为啥打不出来),然后向所有** U2 中的客户端发送请求,让它们每人返回一个 pu,v(u∈U1/U2,v∈U2) 的碎片,也就是每个客户端要返回 |U1/U2|∗|U2| 个碎片。当 |U2|≥t ,中心服务器就可以恢复出了。于是中心服务器这样聚合:
x¯=∑u∈U2yu+∑u∈U1/U2∑v∈U2pu,v=∑u∈U2(xu+∑v∈U1pu,v)+∑u∈U1/U2∑v∈U2pu,v=∑u∈U2(xu+∑v∈U2pu,v+∑v∈U1/U2pu,v)+∑u∈U1/U2∑v∈U2pu,v=∑u∈U2xu
秘密重构的思想其实很简单,就是把各个客户端的秘密扰动拆成碎片后保管在其它客户端手里,一旦某个客户端在训练过程中掉线了,那么中心服务器就可以利用其它在线客户端手里的碎片来恢复出掉线客户端的秘密扰动,从而抵消所有的秘密扰动。
3,Double masking
尽管密钥恢复机制可以提升联邦学习系统对客户端掉线场景的鲁棒性,但是当中心服务器被攻击者控制后,客户端的明文数据是可能会被泄露的。举个例子,假如客户端 u 把加密后的数据 yu=xu+∑v∈U1pu,v 发送给了中心服务器。中心服务器在密钥恢复阶段欺骗其它客户端,说客户端 u 掉线了,从而获取 uu,v(v∈U1) 的碎片,进而恢复出 ∑v∈U1pu,v ,最终通过计算 yu−∑v∈U1pu,v 获得了客户端 u 的真实值。
于是,算法进行改进,引入了“Double masking”来防御中心服务器的恶意行为。也就是每个客户端除了生成 su,v ,还要生成秘密扰动 bu ,这部分秘密扰动不需要和其它客户端协商生成。然后在对 pu,v进行秘密分享的时候,也对 bu 做秘密分享,最后把 yu=xu+∑v∈U1pu,v+bu 发送给中心服务器。
于是上一小节的第⑧步变成了:
⑧ 中心服务器收到了集合为 U2 的客户端发送过来的 yu ( U2⊆U1 ,可能有客户端在⑥⑦掉线了)。如果 U2 是 U1 的真子集,中心服务器直接进行聚合: x¯=∑u∈U2yu=∑u∈U2(xu+∑v∈U1pu,v+bu)=∑u∈U2(xu+∑v∈U1pu,v)+∑u∈U2bu
于是不能抵消的秘密扰动除了掉线客户端的 pu,v ,还有在线客户端的 bu 。接下来一步是重点:
协议规定,中心服务器在向任意在线客户端 u∈U2 索要其它客户端 k∈U1 的秘密碎片的时候,要么只能获取 pk,v(v∈U1) 相关的,要么只能获取 bk 相关的,不能两者都要。
这样的话,中心服务器在向在线客户端索要秘密碎片的时候,对于 k∈U2 ,可以索要 bk 的碎片,而对于 k∈U1/U2 ,选择 pk,v 的碎片。这样的话,中心服务器就可以恢复出在线客户端的 bu 和离线客户端的 pu,v 。从而把所有秘密扰动抵消,同时中心服务器不能抵消掉任一客户端的 bu 和 pu,v ,保护了 xu 的安全。

网图,侵删
但是有聪明的同学问了,如果中心服务器向一半人索要 bu 的碎片,向一半人索要 su,v 的碎片,不就可以恢复出 u 的真实数据吗?对!所以为了保证安全性,一般需要秘密恢复阈值 t>n/2 ( t 参考第2小节的第④部分)。
4,通信效率优化
注意到客户端之间在协商秘密扰动的时候,需要互相发送 su,v ,并且 su,v 的维度和被保护数据 xu 的维度大小是一样的。想象一下,如果被保护的是100万维度的模型参数,并且有100个客户端进行训练,那么每个客户端需要传输1亿个数据,太费流量了!
于是算法可以如下进一步优化:
每个客户端生成一对Diffifie-Hellman公私钥,公钥是 sPK ,私钥是 sSK 。然后每个客户端把自己的公钥发送给中心服务器,中心服务器收集满 suPK(u∈U) 之后,再发送给所有客户端。
接下来客户端 u 这样确定它和 v 之间的秘密扰动 pu,v :先计算 su,v=AGREE(suSK,svPK) ,然后 pu,v=PRG(su,v) (如果 u 的序号大于 v 的序号),pu,v=−PRG(su,v) (如果 u 的序号小于 v 的序号)。其中 PRG(x) 指的是pseudorandom generator,也就是伪随机数生成器,当 x=y , PRG(x)=PRG(y) ,而且 PRG(x) 的维度和 xu 的维度相同。
客户端 v 这样确定它和 u 之间的秘密扰动 pv,u :先计算 sv,u=AGREE(svSK,suPK) ,然后 pv,u=PRG(sv,u) (如果 u 的序号大于 v 的序号),pv,u=−PRG(sv,u) (如果 u 的序号小于 v 的序号)。注意,Diffifie-Hellman密钥有个特性,就是 AGREE(suSK,svPK)=AGREE(svSK,suPK) ,从而 sv,u=su,v ,进而 pu,v+pv,u=0 ,满足前面小节中 pu,v 的性质。
5,客户端的身份验证
按照以上步骤训练就安全了吗 ?那也不一定,因为中心服务器可能会发起中间人攻击。举个例子,比如有A、B、C、D四个客户端,当他们把自己的公钥 sAPK 、 sBPK 、sCPK 、 sDPK发送给中心服务器之后,服务器在本地生成4对公私钥 (sA′PK,sA′SK) 、(sB′PK,sB′SK) 、(sC′PK,sC′SK) 、(sD′PK,sD′SK) ,然后用 sA′PK 、sB′PK 、sC′PK 、sD′PK 替换掉sAPK 、 sBPK 、sCPK 、 sDPK并发送给A、B、C、D。这样的话,客户端A在真实数据上加的秘密扰动就是 PRG(AGREE(sASK,sB′PK)) +PRG(AGREE(sASK,sC′PK)) +PRG(AGREE(sASK,sD′PK)) ,而中心服务器拥有 sA′SK 、sBPK,且 AGREE(sA′SK,sBPK) = AGREE(sASK,sB′PK) (另外两项同理),所以中心服务器可以计算出客户端A发送的数据中的 su,v 类秘密扰动;对于 bu 类的秘密扰动,可以采用第3小节中的方式去还原出来,从而还原出客户端A的原始数据。
那么怎么去进行防御呢?我们注意到,中心服务器能完成攻击是因为它把客户端想要转发的数据调包成了它自己生成的数据。如果我们给每个客户端增加一个数字签名校验机制,验证中心服务器转发过来的数据确实是其它客户端发送的,就可以抵御这类攻击了。数字签名校验的形式是:所有客户端共享一套权威授权的证书,客户端在发送数据之前,可以利用证书中的信息对数据进行签名,然后把数据和签名发送给另一客户端;另一客户端接收到数据和签名后,会利用证书中的信息对数据和签名进行验证,如果验证通过,说明数据是身份合法的客户端发送过来的(下回专门写一篇数字签名的文章,小伙伴可以在评论区提醒我噢)。
好了,基于以上的内容,我们最终就可以得到一个完整版的安全聚合步骤[2]:
6,完整的安全聚合训练步骤

文献[2]

文献[2]
中文翻译:
部署
所有客户端通过可信第三方获得数字签名认证需要的签名密钥 duSK 和对其它客户端数据验签的密钥duSK,一般是根据客户端ID的哈希值去产生密钥。
Round 0 (公钥广播)
客户端 u :
1,产生两对非对称加解密的公、私钥 (cuPK,cuSK) ,(suPK,suSK),前者用于对发送、接收的数据进行加、解密,后者用于协商秘密扰动种子;
2,对两个公钥签名,得到 σu=sign(duSK,cuPK||suPK) ,并且把 cuPK, suPK , σu 发送给中心服务器(以下简称Server)。
Server:
接收至少t 个客户端发送过来的消息(把这些客户端记作 u1 ),并且把数据列表 (v,cvPK,svPK,σv)v∈u1 发送给u1中的所有客户端。
Round 1 (秘钥分享)
客户端 u :
1,收到Server发送的 (v,cvPK,svPK,σv)v∈u1 之后,确认 |u1>t| ,以及各组秘钥都不相同,并且对所有的签名做校验: ∀v∈u1,SIG.ver(dvPK,cvPK||svPK,σu)=1 。
2,生成一个随机种子 bu ,并且把bu和 suSK 进行秘密分享,得到秘密碎片 (v,bu,v)v∈u1 和 (v,su,vSK)v∈u1 。
3,把碎片进行加密: eu,v=AE.enc(KA.agree(cuSK,cvPK),u||v||su,vSK||bu,v||) ,然后把 (eu,v,u,v) 发给Server。注意,此处加密消息的秘钥是 KA.agree(cuSK,cvPK) ,这样的话,Server把这个消息转发给客户端 v 之后, v 就可以用秘钥 KA.agree(cvSK,cuPK) 去解密。此外, u 把 (eu,v,u,v) 发送给Server是为了告诉Server这个“快递”是发给客户端 v 的。
Server:
接收至少 t 个客户端发送过来的加密信息,把这些客户端记作 u2 ,且要验证 u2⊆u1 ;
然后对于所有 u∈u2 ,把 {ev,u}v∈u2 发送给客户端 u 。
Round 2 (数据加密)
客户端 u :
1,收到Server发送过来的加密信息 {ev,u}v∈u2 ,确认 |u2|>t ,保存好。
2,对于所有 v∈u2/u ,计算秘密扰动种子 su,v=KA.agree(suSK,svPK) ,然后使用PRG算法生成扰动向量 pu,v=Δu,v⋅PRG(su,v) 。当u>v , Δu,v=1;当u<v , Δu,v=−1。
3,计算自身的秘密种子生成的扰动向量 pu=PRG(bu) ,然后把扰动向量加到真实数据上,得到加密数据: yu=xu+pu+∑v∈u2pu,v ,然后把 yu 发送给Server。
Server:
接收至少 t 个客户端发送过来的加密信息,把这些客户端记作 u3 ,且要验证 u3⊆u2 ,然后把用户列表 u3 发送给这些客户端。
Round 3 (一致性校验)
客户端 u :
接收Server发送过来的 u3 列表,确认 |u3|>t ,然后对列表签名得到 σu′=SIG.sign(duSK,u3) ,再把 σu′ 发送给Server。
Server:
接收至少 t 个客户端发送过来的签名信息,把这些客户端记作 u4 ,且要验证 u4⊆u3 ,然后把 (v,σv′)v∈u4 发送给这些客户端。
Round 4 (解密)
客户端 u :
1,接收Server发送过来的签名列表 (v,σv′)v∈u4 ,确认 u4⊆u3 ,以及 |u4|>t 。并且对于所有 v∈u4 ,验证 SIG.ver(dvPK,u3,σv′)=1 。注意,到这一步才算完成了一致性校验,一致性校验可以保证各客户端对所有客户端的在线情况认知是一致的,从而保证解密过程的正确性。
2,对于所有的 v∈u2 ,解密Round 2收到的 ev,u ,得到 v′||u′||sv,uSK||bv,u=AE.dec(KA.agree(cuSK,cvPK),ev,u) ;
3,把 sv,uSK(v∈u2/u3) 和 bv,u(v∈u3) 发送给Server。
Server:
接收至少 t 个客户端发送过来的碎片信息,把这些客户端记作 u5 。
对于掉线的客户端 u∈u2/u3 ,重建它的私钥 suSK=SS.recon((su,vSK)v∈u5,t) ,并使用它和 u3 的公钥生成 pu,v(v∈u3) 。
对于在线的客户端 u∈u3 ,重建它的个人秘钥 bu=SS.recon((bu,v)v∈u5,t) ,并生成 pu 。
计算最终的聚合值:
z=∑u∈u3yu−∑u∈u3pu+∑u∈u3,v∈u2/u3pv,u=∑u∈u3(xu+pu+∑v∈u2pu,v)−∑u∈u3pu+∑u∈u3,v∈u2/u3pv,u=∑u∈u3xu
参考
- 1.Practical Secure Aggregation for Federated Learning on User-Held Data https://arxiv.org/pdf/1611.04482.pdf
- 2.abPractical Secure Aggregation for Privacy-Preserving Machine Learning https://dl.acm.org/doi/pdf/10.1145/3133956.3133982