chsm-server/chsm-web-server/src/main/java/com/sunyard/chsm/service/SymKeyService.java
2024-12-19 16:42:47 +08:00

232 lines
9.6 KiB
Java

package com.sunyard.chsm.service;
import com.sunyard.chsm.auth.UserContext;
import com.sunyard.chsm.enums.AlgMode;
import com.sunyard.chsm.enums.KeyAlg;
import com.sunyard.chsm.enums.KeyCategory;
import com.sunyard.chsm.enums.KeyStatus;
import com.sunyard.chsm.enums.KeyUsage;
import com.sunyard.chsm.enums.Padding;
import com.sunyard.chsm.mapper.KeyInfoMapper;
import com.sunyard.chsm.mapper.SpKeyRecordMapper;
import com.sunyard.chsm.model.entity.KeyInfo;
import com.sunyard.chsm.model.entity.KeyRecord;
import com.sunyard.chsm.param.SymDecryptReq;
import com.sunyard.chsm.param.SymDecryptResp;
import com.sunyard.chsm.param.SymEncryptReq;
import com.sunyard.chsm.param.SymEncryptResp;
import com.sunyard.chsm.param.SymHmacCheckReq;
import com.sunyard.chsm.param.SymHmacCheckResp;
import com.sunyard.chsm.param.SymHmacReq;
import com.sunyard.chsm.param.SymHmacResp;
import com.sunyard.chsm.param.SymMacCheckReq;
import com.sunyard.chsm.param.SymMacCheckResp;
import com.sunyard.chsm.param.SymMacReq;
import com.sunyard.chsm.param.SymMacResp;
import com.sunyard.chsm.sdf.SdfApiService;
import com.sunyard.chsm.sdf.context.AlgId;
import com.sunyard.chsm.utils.CodecUtils;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import java.security.MessageDigest;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.Objects;
/**
 * Symmetric-key operations exposed to applications: encrypt/decrypt, HMAC and MAC,
 * each with a matching verification entry point where applicable.
 *
 * <p>Every operation validates the key through {@code checkKey}, loads a key record,
 * decrypts the stored key data with {@code sdfApiService.decryptByTMK} and then delegates
 * the cryptographic computation to {@link SdfApiService}. Validation failures raised through
 * Spring's {@link Assert} surface as {@link IllegalArgumentException}.
 *
 * @author liulu
 * @since 2024/12/17
 */
@Service
@RequiredArgsConstructor
public class SymKeyService {

    private final KeyInfoMapper keyInfoMapper;
    private final SpKeyRecordMapper spKeyRecordMapper;
    private final SdfApiService sdfApiService;

    /**
     * Encrypts Base64-encoded plaintext with the record returned by
     * {@code selectUsedByKeyId} for the requested key.
     *
     * @param req key id, mode, padding, Base64 plaintext and (for CBC) Base64 IV
     * @return key id, index of the key record that was used, and Base64 ciphertext
     * @throws IllegalArgumentException if the IV, mode, key or key record fails validation
     * @throws UnsupportedOperationException if the key algorithm or mode is not supported
     */
    public SymEncryptResp encrypt(SymEncryptReq req) {
        byte[] iv = resolveIv(req.getMode(), req.getIv());
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());

        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.ENCRYPT_DECRYPT);
        AlgId algId = resolveCipherAlgId(keyInfo, req.getMode());
        // The request is updated in place, as before, so callers observe the normalized padding.
        req.setPadding(normalizePadding(req.getPadding()));

        KeyRecord keyRecord = spKeyRecordMapper.selectUsedByKeyId(keyInfo.getId());
        Assert.notNull(keyRecord, "数据异常");

        byte[] symKey = unwrapKey(keyRecord);
        byte[] cipherData = sdfApiService.symEncrypt(algId, req.getPadding(), symKey, iv, plain);

        SymEncryptResp resp = new SymEncryptResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setCipherData(CodecUtils.encodeBase64(cipherData));
        return resp;
    }

    /**
     * Decrypts Base64-encoded ciphertext with the key record identified by the request's key index.
     * The record must belong to the requested key id.
     *
     * @param req key id, key index, mode, padding, Base64 ciphertext and (for CBC) Base64 IV
     * @return key id, key index and Base64 plaintext
     * @throws IllegalArgumentException if the IV, mode, key or key record fails validation
     * @throws UnsupportedOperationException if the key algorithm or mode is not supported
     */
    public SymDecryptResp decrypt(SymDecryptReq req) {
        byte[] iv = resolveIv(req.getMode(), req.getIv());
        byte[] cipher = CodecUtils.decodeBase64(req.getCipherData());

        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.ENCRYPT_DECRYPT);
        AlgId algId = resolveCipherAlgId(keyInfo, req.getMode());
        req.setPadding(normalizePadding(req.getPadding()));

        KeyRecord keyRecord = spKeyRecordMapper.selectById(Long.valueOf(req.getKeyIndex()));
        requireRecordOfKey(keyRecord, keyInfo);

        byte[] symKey = unwrapKey(keyRecord);
        byte[] plain = sdfApiService.symDecrypt(algId, req.getPadding(), symKey, iv, cipher);

        SymDecryptResp resp = new SymDecryptResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setPlainData(CodecUtils.encodeBase64(plain));
        return resp;
    }

    /**
     * Computes an HMAC over Base64-encoded data with the record returned by
     * {@code selectUsedByKeyId} for the requested key.
     *
     * @param req key id and Base64 data
     * @return key id, index of the key record that was used, and Base64 HMAC
     * @throws IllegalArgumentException if the key or key record fails validation
     */
    public SymHmacResp hmac(SymHmacReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.HMAC);

        KeyRecord keyRecord = spKeyRecordMapper.selectUsedByKeyId(keyInfo.getId());
        // Same guard as the other operations; without it a missing record caused an NPE.
        Assert.notNull(keyRecord, "数据异常");

        byte[] hmac = sdfApiService.hmac(unwrapKey(keyRecord), plain);

        SymHmacResp resp = new SymHmacResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setHmac(CodecUtils.encodeBase64(hmac));
        return resp;
    }

    /**
     * Recomputes the HMAC with the key record identified by the request's key index and
     * compares it with the supplied value.
     *
     * @param req key id, key index, Base64 data and Base64 HMAC to verify
     * @return response whose {@code valid} flag tells whether the HMACs match
     * @throws IllegalArgumentException if the key or key record fails validation
     */
    public SymHmacCheckResp hmacCheck(SymHmacCheckReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        byte[] originHmac = CodecUtils.decodeBase64(req.getHmac());
        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.HMAC);

        KeyRecord keyRecord = spKeyRecordMapper.selectById(Long.valueOf(req.getKeyIndex()));
        requireRecordOfKey(keyRecord, keyInfo);

        byte[] hmac = sdfApiService.hmac(unwrapKey(keyRecord), plain);

        SymHmacCheckResp resp = new SymHmacCheckResp();
        // Constant-time comparison: an early-exit compare would leak how many bytes matched.
        resp.setValid(MessageDigest.isEqual(hmac, originHmac));
        return resp;
    }

    /**
     * Computes an SM4 MAC ({@code AlgId.SGD_SM4_MAC}) over Base64-encoded data with the record
     * returned by {@code selectUsedByKeyId} for the requested key.
     *
     * @param req key id, padding, Base64 data and Base64 IV (at least 16 bytes once decoded)
     * @return key id, index of the key record that was used, and Base64 MAC
     * @throws IllegalArgumentException if the IV, key or key record fails validation
     */
    public SymMacResp mac(SymMacReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        byte[] iv = CodecUtils.decodeBase64(req.getIv());
        Assert.isTrue(iv != null && iv.length >= 16, "iv长度至少为16");

        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.MAC);
        KeyRecord keyRecord = spKeyRecordMapper.selectUsedByKeyId(keyInfo.getId());
        Assert.notNull(keyRecord, "数据异常");
        req.setPadding(normalizePadding(req.getPadding()));

        byte[] mac = sdfApiService.calculateMAC(
                AlgId.SGD_SM4_MAC, req.getPadding(), unwrapKey(keyRecord), iv, plain);

        SymMacResp resp = new SymMacResp();
        resp.setKeyId(keyInfo.getId());
        resp.setKeyIndex(keyRecord.getKeyIndex());
        resp.setMac(CodecUtils.encodeBase64(mac));
        return resp;
    }

    /**
     * Recomputes the SM4 MAC with the key record identified by the request's key index and
     * compares it with the supplied value.
     *
     * @param req key id, key index, padding, Base64 data, Base64 IV and Base64 MAC to verify
     * @return response whose {@code valid} flag tells whether the MACs match
     * @throws IllegalArgumentException if the IV, key or key record fails validation
     */
    public SymMacCheckResp macCheck(SymMacCheckReq req) {
        byte[] plain = CodecUtils.decodeBase64(req.getPlainData());
        byte[] iv = CodecUtils.decodeBase64(req.getIv());
        // Mirror the IV validation performed by mac() so both paths reject the same inputs.
        Assert.isTrue(iv != null && iv.length >= 16, "iv长度至少为16");
        byte[] originMac = CodecUtils.decodeBase64(req.getMac());

        KeyInfo keyInfo = checkKey(req.getKeyId(), KeyUsage.MAC);
        KeyRecord keyRecord = spKeyRecordMapper.selectById(Long.valueOf(req.getKeyIndex()));
        requireRecordOfKey(keyRecord, keyInfo);
        req.setPadding(normalizePadding(req.getPadding()));

        byte[] mac = sdfApiService.calculateMAC(
                AlgId.SGD_SM4_MAC, req.getPadding(), unwrapKey(keyRecord), iv, plain);

        SymMacCheckResp resp = new SymMacCheckResp();
        // Constant-time comparison: an early-exit compare would leak how many bytes matched.
        resp.setValid(MessageDigest.isEqual(mac, originMac));
        return resp;
    }

    /**
     * Loads the key and verifies that the current application may use it for {@code usage}:
     * it must exist, belong to the current application, be a symmetric key, be enabled,
     * be inside its effective/expired time window and carry the requested usage.
     *
     * @param keyId id of the key to look up
     * @param usage operation the caller wants to perform
     * @return the validated key
     * @throws IllegalArgumentException if any of the checks fails
     */
    private KeyInfo checkKey(Long keyId, KeyUsage usage) {
        KeyInfo keyInfo = keyInfoMapper.selectById(keyId);
        Assert.notNull(keyInfo, "密钥ID不存在");
        Assert.isTrue(Objects.equals(keyInfo.getApplicationId(), UserContext.getCurrentAppId()), "您无权使用此密钥ID");
        Assert.isTrue(KeyCategory.SYM_KEY.getCode().equals(keyInfo.getKeyType()), "此密钥不是对称密钥");

        KeyStatus status = KeyStatus.of(keyInfo.getStatus());
        Assert.isTrue(KeyStatus.ENABLED == status, "此密钥不是启用状态, 无法操作");

        // NOTE(review): assumes effectiveTime and expiredTime are never null - confirm against schema.
        LocalDateTime now = LocalDateTime.now();
        boolean withinValidity = now.isAfter(keyInfo.getEffectiveTime()) && now.isBefore(keyInfo.getExpiredTime());
        Assert.isTrue(withinValidity, "此密钥不是启用状态, 无法操作");

        Assert.isTrue(KeyUsage.hasUsage(keyInfo.getKeyUsage(), usage), "此密钥无权进行" + usage.getDesc() + "操作");
        return keyInfo;
    }

    /** Returns the decoded IV for CBC mode (validated to be at least 16 bytes), or an empty array otherwise. */
    private static byte[] resolveIv(AlgMode mode, String ivBase64) {
        if (AlgMode.CBC != mode) {
            return new byte[0];
        }
        Assert.hasText(ivBase64, "CBC模式iv不能为空");
        byte[] iv = CodecUtils.decodeBase64(ivBase64);
        Assert.isTrue(iv.length >= 16, "iv长度至少为16");
        return iv;
    }

    /** Maps the key's algorithm and the requested mode to an SDF algorithm id, rejecting unsupported combinations. */
    private static AlgId resolveCipherAlgId(KeyInfo keyInfo, AlgMode mode) {
        KeyAlg keyAlg = KeyAlg.of(keyInfo.getKeyAlg());
        Assert.notNull(keyAlg, "数据异常");
        switch (keyAlg) {
            case SM4:
                // Previously a null mode caused an NPE and an unknown mode passed a null algId downstream.
                Assert.notNull(mode, "算法模式不能为空");
                switch (mode) {
                    case ECB:
                        return AlgId.SGD_SM4_ECB;
                    case CBC:
                        return AlgId.SGD_SM4_CBC;
                    default:
                        throw new UnsupportedOperationException("不支持的算法模式:" + mode);
                }
            default:
                throw new UnsupportedOperationException("不支持的密钥算法:" + keyAlg.getCode());
        }
    }

    /** Treats {@code PCKS5Padding} as an alias of {@code PCKS7Padding}; any other value is returned unchanged. */
    private static Padding normalizePadding(Padding padding) {
        return Padding.PCKS5Padding == padding ? Padding.PCKS7Padding : padding;
    }

    /** Ensures the key record exists and belongs to the given key. */
    private static void requireRecordOfKey(KeyRecord keyRecord, KeyInfo keyInfo) {
        Assert.notNull(keyRecord, "数据异常");
        Assert.isTrue(Objects.equals(keyRecord.getKeyId(), keyInfo.getId()), "密钥Id和密钥索引不匹配");
    }

    /** Hex-decodes the record's stored key data and decrypts it via {@code sdfApiService.decryptByTMK}. */
    private byte[] unwrapKey(KeyRecord keyRecord) {
        return sdfApiService.decryptByTMK(CodecUtils.decodeHex(keyRecord.getKeyData()));
    }
}