package com.sunyard.chsm.service;

import com.sunyard.chsm.auth.UserContext;
import com.sunyard.chsm.enums.AlgMode;
import com.sunyard.chsm.enums.KeyAlg;
import com.sunyard.chsm.enums.KeyCategory;
import com.sunyard.chsm.enums.KeyStatus;
import com.sunyard.chsm.enums.KeyUsage;
import com.sunyard.chsm.enums.Padding;
import com.sunyard.chsm.mapper.KeyInfoMapper;
import com.sunyard.chsm.mapper.SpKeyRecordMapper;
import com.sunyard.chsm.model.entity.KeyInfo;
import com.sunyard.chsm.model.entity.KeyRecord;
import com.sunyard.chsm.param.SymDecryptReq;
import com.sunyard.chsm.param.SymDecryptResp;
import com.sunyard.chsm.param.SymEncryptReq;
import com.sunyard.chsm.param.SymEncryptResp;
import com.sunyard.chsm.param.SymHmacCheckReq;
import com.sunyard.chsm.param.SymHmacCheckResp;
import com.sunyard.chsm.param.SymHmacReq;
import com.sunyard.chsm.param.SymHmacResp;
import com.sunyard.chsm.param.SymMacCheckReq;
import com.sunyard.chsm.param.SymMacCheckResp;
import com.sunyard.chsm.param.SymMacReq;
import com.sunyard.chsm.param.SymMacResp;
import com.sunyard.chsm.sdf.SdfApiService;
import com.sunyard.chsm.sdf.context.AlgId;
import com.sunyard.chsm.utils.CodecUtils;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

import java.security.MessageDigest;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.Objects;

/**
 * Symmetric-key operations (encrypt/decrypt, HMAC, MAC) performed on behalf of the
 * current application using keys stored in {@code KeyInfo}/{@code KeyRecord}.
 *
 * <p>Validation failures are reported through Spring's {@link Assert}, which throws
 * {@link IllegalArgumentException}.
 *
 * @author liulu
 * @since 2024/12/17
 */
@Service
@RequiredArgsConstructor
public class SymKeyService {

    /** Minimum accepted IV length in bytes for CBC encryption and MAC calculation. */
    private static final int MIN_IV_LENGTH = 16;

    private final KeyInfoMapper keyInfoMapper;
    private final SpKeyRecordMapper spKeyRecordMapper;
    private final SdfApiService sdfApiService;

    /**
     * Encrypts Base64-encoded plaintext with the record returned by
     * {@code selectUsedByKeyId} for the requested key.
     *
     * @param req request carrying key id, mode, padding, optional IV (Base64) and plaintext (Base64);
     *            a PKCS5 padding value is rewritten to PKCS7 on the request
     * @return key id, key index of the record used, and Base64 ciphertext
     * @throws IllegalArgumentException on invalid input, key state or missing data
     * @throws UnsupportedOperationException if the key algorithm or mode is not supported
     */
    public SymEncryptResp encrypt(SymEncryptReq req) {
        byte[] iv = parseCipherIv(req.getMode(), req.getIv());
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.ENCRYPT_DECRYPT);
        AlgId algId = resolveCipherAlgId(keyInfo, req.getMode());
        Padding padding = normalizePadding(req.getPadding());
        req.setPadding(padding);

        KeyRecord keyRecord = loadUsedRecord(keyInfo);
        byte[] symKey = unwrapKey(keyRecord);
        byte[] cipherData = sdfApiService.symEncrypt(algId, padding, symKey, iv, plain);

        SymEncryptResp resp = new SymEncryptResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setCipherData(CodecUtils.encodeBase64(cipherData));
        return resp;
    }

    /**
     * Decrypts Base64-encoded ciphertext with the specific key record named by the request's
     * key index, which must belong to the requested key id.
     *
     * @param req request carrying key id, key index, mode, padding, optional IV (Base64) and
     *            ciphertext (Base64); a PKCS5 padding value is rewritten to PKCS7 on the request
     * @return key id, key index and Base64 plaintext
     * @throws IllegalArgumentException on invalid input, key state, or id/index mismatch
     * @throws UnsupportedOperationException if the key algorithm or mode is not supported
     */
    public SymDecryptResp decrypt(SymDecryptReq req) {
        byte[] iv = parseCipherIv(req.getMode(), req.getIv());
        byte[] cipher = CodecUtils.decodeBase64(req.getCipherData());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.ENCRYPT_DECRYPT);
        AlgId algId = resolveCipherAlgId(keyInfo, req.getMode());
        Padding padding = normalizePadding(req.getPadding());
        req.setPadding(padding);

        KeyRecord keyRecord = loadRecord(Long.valueOf(req.getKeyIndex()), keyInfo);
        byte[] symKey = unwrapKey(keyRecord);
        byte[] plain = sdfApiService.symDecrypt(algId, padding, symKey, iv, cipher);

        SymDecryptResp resp = new SymDecryptResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setPlainData(CodecUtils.encodeBase64(plain));
        return resp;
    }

    /**
     * Computes an HMAC over Base64-encoded data with the record returned by
     * {@code selectUsedByKeyId} for the requested key.
     *
     * @param req request carrying key id and data (Base64)
     * @return key id, key index of the record used, and Base64 HMAC
     * @throws IllegalArgumentException on invalid key state or missing key record
     */
    public SymHmacResp hmac(SymHmacReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.HMAC);
        // Previously unchecked: a missing record surfaced as a NullPointerException.
        KeyRecord keyRecord = loadUsedRecord(keyInfo);
        byte[] symKey = unwrapKey(keyRecord);
        byte[] hmac = sdfApiService.hmac(symKey, plain);

        SymHmacResp resp = new SymHmacResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setHmac(CodecUtils.encodeBase64(hmac));
        return resp;
    }

    /**
     * Recomputes the HMAC with the key record named by the request's key index and compares it
     * to the supplied value.
     *
     * @param req request carrying key id, key index, data (Base64) and expected HMAC (Base64)
     * @return response whose {@code valid} flag is true when the HMACs match
     * @throws IllegalArgumentException on invalid key state or id/index mismatch
     */
    public SymHmacCheckResp hmacCheck(SymHmacCheckReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        byte[] originHmac = CodecUtils.decodeBase64(req.getHmac());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.HMAC);
        KeyRecord keyRecord = loadRecord(Long.valueOf(req.getKeyIndex()), keyInfo);
        byte[] symKey = unwrapKey(keyRecord);
        byte[] hmac = sdfApiService.hmac(symKey, plain);

        SymHmacCheckResp resp = new SymHmacCheckResp();
        // Constant-time comparison so verification time does not leak how many bytes matched.
        resp.setValid(MessageDigest.isEqual(hmac, originHmac));
        return resp;
    }

    /**
     * Computes an SM4 MAC ({@code SGD_SM4_MAC}) over Base64-encoded data with the record returned
     * by {@code selectUsedByKeyId} for the requested key.
     *
     * @param req request carrying key id, padding, IV (Base64, at least 16 bytes) and data (Base64);
     *            a PKCS5 padding value is rewritten to PKCS7 on the request
     * @return key id, key index of the record used, and Base64 MAC
     * @throws IllegalArgumentException on invalid IV, key state or missing key record
     */
    public SymMacResp mac(SymMacReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        byte[] iv = parseMacIv(req.getIv());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.MAC);
        KeyRecord keyRecord = loadUsedRecord(keyInfo);
        Padding padding = normalizePadding(req.getPadding());
        req.setPadding(padding);

        byte[] symKey = unwrapKey(keyRecord);
        byte[] mac = sdfApiService.calculateMAC(AlgId.SGD_SM4_MAC, padding, symKey, iv, plain);

        SymMacResp resp = new SymMacResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setMac(CodecUtils.encodeBase64(mac));
        return resp;
    }

    /**
     * Recomputes the SM4 MAC with the key record named by the request's key index and compares it
     * to the supplied value.
     *
     * @param req request carrying key id, key index, padding, IV (Base64, at least 16 bytes),
     *            data (Base64) and expected MAC (Base64)
     * @return response whose {@code valid} flag is true when the MACs match
     * @throws IllegalArgumentException on invalid IV, key state or id/index mismatch
     */
    public SymMacCheckResp macCheck(SymMacCheckReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        // Same IV validation as mac(); previously missing here.
        byte[] iv = parseMacIv(req.getIv());
        byte[] originMac = CodecUtils.decodeBase64(req.getMac());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.MAC);
        KeyRecord keyRecord = loadRecord(Long.valueOf(req.getKeyIndex()), keyInfo);
        Padding padding = normalizePadding(req.getPadding());
        req.setPadding(padding);

        byte[] symKey = unwrapKey(keyRecord);
        byte[] mac = sdfApiService.calculateMAC(AlgId.SGD_SM4_MAC, padding, symKey, iv, plain);

        SymMacCheckResp resp = new SymMacCheckResp();
        // Constant-time comparison so verification time does not leak how many bytes matched.
        resp.setValid(MessageDigest.isEqual(mac, originMac));
        return resp;
    }

    /**
     * Loads the key and verifies it may be used by the current application for {@code usage}.
     *
     * <p>The checks are that the key exists, belongs to {@code UserContext.getCurrentAppId()}, is a
     * symmetric key, and is ENABLED. It must also be strictly inside its effective/expired window
     * and carry the requested usage flag.
     */
    private KeyInfo checkKey(Long keyId, KeyUsage usage) {
        KeyInfo keyInfo = keyInfoMapper.selectById(keyId);
        Assert.notNull(keyInfo, "密钥ID不存在");
        Assert.isTrue(Objects.equals(keyInfo.getApplicationId(), UserContext.getCurrentAppId()), "您无权使用此密钥ID");
        Assert.isTrue(KeyCategory.SYM_KEY.getCode().equals(keyInfo.getKeyType()), "此密钥不是对称密钥");
        KeyStatus status = KeyStatus.of(keyInfo.getStatus());
        Assert.isTrue(KeyStatus.ENABLED == status, "此密钥不是启用状态, 无法操作");
        LocalDateTime now = LocalDateTime.now();
        // Distinct message: the status check above already passed, so this failure is about the validity window.
        Assert.isTrue(now.isAfter(keyInfo.getEffectiveTime()) && now.isBefore(keyInfo.getExpiredTime()),
                "此密钥不在有效期内, 无法操作");
        Assert.isTrue(KeyUsage.hasUsage(keyInfo.getKeyUsage(), usage), "此密钥无权进行" + usage.getDesc() + "操作");
        return keyInfo;
    }

    /** Returns the decoded CBC IV (at least 16 bytes), or an empty array for non-CBC modes. */
    private static byte[] parseCipherIv(AlgMode mode, String ivBase64) {
        if (AlgMode.CBC != mode) {
            return new byte[0];
        }
        Assert.hasText(ivBase64, "CBC模式iv不能为空");
        byte[] iv = CodecUtils.decodeBase64(ivBase64);
        Assert.isTrue(iv.length >= MIN_IV_LENGTH, "iv长度至少为16");
        return iv;
    }

    /** Returns the decoded MAC IV, requiring it to be present and at least 16 bytes long. */
    private static byte[] parseMacIv(String ivBase64) {
        Assert.hasText(ivBase64, "iv不能为空");
        byte[] iv = CodecUtils.decodeBase64(ivBase64);
        Assert.isTrue(iv.length >= MIN_IV_LENGTH, "iv长度至少为16");
        return iv;
    }

    /** Maps PKCS5 padding to PKCS7 (the form passed to the SDF layer); other values pass through. */
    private static Padding normalizePadding(Padding padding) {
        return Padding.PCKS5Padding == padding ? Padding.PCKS7Padding : padding;
    }

    /**
     * Resolves the SDF algorithm id for a symmetric cipher operation.
     * Only SM4 keys in ECB or CBC mode are supported; anything else fails explicitly
     * instead of passing a null algorithm id to the SDF layer.
     */
    private static AlgId resolveCipherAlgId(KeyInfo keyInfo, AlgMode mode) {
        KeyAlg keyAlg = KeyAlg.of(keyInfo.getKeyAlg());
        Assert.notNull(keyAlg, "数据异常");
        if (KeyAlg.SM4 != keyAlg) {
            throw new UnsupportedOperationException("不支持的密钥算法:" + keyAlg.getCode());
        }
        Assert.notNull(mode, "加密模式不能为空");
        switch (mode) {
            case ECB:
                return AlgId.SGD_SM4_ECB;
            case CBC:
                return AlgId.SGD_SM4_CBC;
            default:
                throw new UnsupportedOperationException("不支持的加密模式:" + mode);
        }
    }

    /** Loads the record returned by {@code selectUsedByKeyId} for the key; fails if none exists. */
    private KeyRecord loadUsedRecord(KeyInfo keyInfo) {
        KeyRecord keyRecord = spKeyRecordMapper.selectUsedByKeyId(keyInfo.getId());
        Assert.notNull(keyRecord, "数据异常");
        return keyRecord;
    }

    /** Loads a record by id and verifies it belongs to the given key. */
    private KeyRecord loadRecord(Long recordId, KeyInfo keyInfo) {
        KeyRecord keyRecord = spKeyRecordMapper.selectById(recordId);
        Assert.notNull(keyRecord, "数据异常");
        Assert.isTrue(Objects.equals(keyRecord.getKeyId(), keyInfo.getId()), "密钥Id和密钥索引不匹配");
        return keyRecord;
    }

    /** Hex-decodes the record's stored key data and unwraps it via {@code sdfApiService.decryptByTMK}. */
    private byte[] unwrapKey(KeyRecord keyRecord) {
        return sdfApiService.decryptByTMK(CodecUtils.decodeHex(keyRecord.getKeyData()));
    }
}