chsm-server/chsm-web-server/src/main/java/com/sunyard/chsm/pool/DeviceManager.java
2024-12-12 10:05:37 +08:00

245 lines
9.7 KiB
Java

package com.sunyard.chsm.pool;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.sunyard.chsm.auth.AppUser;
import com.sunyard.chsm.constant.SecurityConstant;
import com.sunyard.chsm.enums.DeviceTmkStatus;
import com.sunyard.chsm.mapper.CryptoServiceDeviceGroupMapper;
import com.sunyard.chsm.mapper.SpDeviceMapper;
import com.sunyard.chsm.model.dto.DeviceCheckRes;
import com.sunyard.chsm.model.entity.CryptoServiceDeviceGroup;
import com.sunyard.chsm.model.entity.Device;
import com.sunyard.chsm.sdf.adapter.SdfApiAdapter;
import com.sunyard.chsm.sdf.adapter.SdfApiAdapterFactory;
import com.sunyard.chsm.sdf.model.EccCipher;
import com.sunyard.chsm.service.TmkService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
 * Keeps, per crypto service, the list of usable cipher devices and hands out a
 * {@link TMKContext} for the current request.
 * <p>
 * A background task started in {@link #afterPropertiesSet()} re-reads the device configuration
 * from the database every {@value #SYNC_INTERVAL_MINUTES} minutes. Devices are chosen by weighted
 * round robin per application. When no device context can be obtained, a software (BC adapter)
 * context is used, provided the soft device is enabled.
 * <p>
 * Thread-safety: {@code deviceMap} is an unmodifiable snapshot that only the single sync thread
 * replaces (never mutates), so request threads can read it without locking.
 *
 * @author liulu
 * @since 2024/12/10
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class DeviceManager implements InitializingBean {

    /** Delay between two device synchronisation runs, in minutes. */
    private static final long SYNC_INTERVAL_MINUTES = 5L;

    /** Per-application round-robin counters; accessed concurrently by request threads. */
    private static final Map<Long, AtomicInteger> ROUND_MAP = new ConcurrentHashMap<>();

    /**
     * serviceId -> usable device contexts. Replaced as a whole by the sync thread and never
     * mutated in place; volatile so request threads always see a fully built snapshot.
     */
    private volatile Map<Long, List<DeviceContext>> deviceMap = Collections.emptyMap();

    /** Written by the sync thread, read by request threads. */
    private volatile boolean enableSoftDevice = false;

    /** Lazily created software context; guarded by {@code this} (see {@link #getSoftContext()}). */
    private TMKContext softContext;

    private final TmkService tmkService;
    private final SpDeviceMapper spDeviceMapper;
    private final CryptoServiceDeviceGroupMapper cryptoServiceDeviceGroupMapper;

    /**
     * Returns a context obtained from {@link #chooseOne()} to the device pool it was borrowed from.
     * Contexts without a pool (such as the shared soft-device context) are left untouched.
     *
     * @param context the context to release; {@code null} is ignored
     */
    public void returnContextToPool(TMKContext context) {
        if (context == null) {
            return;
        }
        Optional.ofNullable(context.getPool())
                .ifPresent(it -> {
                    // Clear the back-reference first so a repeated call does not return it twice.
                    context.setPool(null);
                    it.returnObject(context);
                });
    }

    /**
     * Picks a TMK context for the application bound to the current request.
     * <p>
     * The request attributes must contain the logged-in {@link AppUser} and the list of service
     * ids ({@code used_service_ids}) the application may use. A device is chosen by weighted round
     * robin across the devices of all those services. If there is no device, or borrowing from its
     * pool fails, the soft-device context is used instead.
     *
     * @return a context; when it came from a device pool it must be released through
     * {@link #returnContextToPool(TMKContext)}
     * @throws IllegalArgumentException if the user or service list is missing, or no device
     *                                  (hardware or soft) is available
     */
    public TMKContext chooseOne() {
        RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
        AppUser user = (AppUser) attributes.getAttribute(SecurityConstant.ATTRIBUTE_APP_USER, RequestAttributes.SCOPE_REQUEST);
        Assert.notNull(user, "登录用户不能为空");
        //noinspection unchecked
        List<Long> serviceIds = (List<Long>) attributes.getAttribute("used_service_ids", RequestAttributes.SCOPE_REQUEST);
        Assert.isTrue(!CollectionUtils.isEmpty(serviceIds), "应用: " + user.getName() + "没有可用服务");

        // Masking the sign bit keeps the value non-negative once the counter wraps around.
        // This is atomic, unlike a separate check-then-reset.
        int calls = ROUND_MAP.computeIfAbsent(user.getAppId(), k -> new AtomicInteger(1))
                .getAndIncrement() & Integer.MAX_VALUE;

        // Read the volatile snapshot once so that every service lookup sees the same version.
        Map<Long, List<DeviceContext>> snapshot = deviceMap;
        List<DeviceContext> contexts = new ArrayList<>();
        for (Long serviceId : serviceIds) {
            Optional.ofNullable(snapshot.get(serviceId))
                    .ifPresent(contexts::addAll);
        }
        DeviceContext device = getNextDevice(contexts, calls);
        TMKContext tmkContext;
        if (device == null) {
            tmkContext = getSoftContext();
        } else {
            try {
                GenericObjectPool<TMKContext> pool = device.getPool();
                tmkContext = pool.borrowObject(2000);
                tmkContext.setPool(pool);
            } catch (Exception e) {
                log.warn("borrow context from device {} failed, falling back to soft device",
                        device.getDeviceSerial(), e);
                tmkContext = getSoftContext();
            }
        }
        Assert.notNull(tmkContext, "应用: " + user.getName() + "没有可用的密码设备");
        return tmkContext;
    }

    /**
     * Returns the lazily created software (BC adapter) context.
     * <p>
     * The context is created from the soft-device encrypted TMK on first use and then cached.
     *
     * @return the soft context, or {@code null} if the soft device is disabled or has no encrypted TMK
     */
    public synchronized TMKContext getSoftContext() {
        if (!enableSoftDevice) {
            return null;
        }
        if (Objects.nonNull(softContext)) {
            return softContext;
        }
        byte[] softEncTmk = tmkService.getSoftDeviceEncTmk();
        if (softEncTmk == null || softEncTmk.length == 0) {
            return null;
        }
        SdfApiAdapter bcAdapter = SdfApiAdapterFactory.getBcAdapter();
        String hk = bcAdapter.importKeyWithISKECC("", 1, EccCipher.fromBytes(softEncTmk));
        softContext = new TMKContext();
        softContext.setSdfApiAdapter(bcAdapter);
        softContext.setKeyHandle(hk);
        return softContext;
    }

    /**
     * Selects a device by smooth-free weighted round robin.
     * <p>
     * Device {@code i} is returned for a share of calls proportional to its weight. Null or
     * negative weights count as 0. If every weight is 0, the devices are rotated with equal share
     * instead of failing with a division by zero.
     *
     * @param devices    candidate devices; may be {@code null} or empty
     * @param totalCalls running call counter; negative values are handled via floor modulo
     * @return the chosen device, or {@code null} when there are no candidates
     */
    public static DeviceContext getNextDevice(List<DeviceContext> devices, int totalCalls) {
        if (devices == null || devices.isEmpty()) {
            return null;
        }
        int totalWeight = 0;
        for (DeviceContext device : devices) {
            totalWeight += weightOf(device);
        }
        if (totalWeight <= 0) {
            return devices.get(Math.floorMod(totalCalls, devices.size()));
        }
        int index = Math.floorMod(totalCalls, totalWeight);
        for (DeviceContext device : devices) {
            int weight = weightOf(device);
            if (index < weight) {
                return device;
            }
            index -= weight;
        }
        return null; // unreachable: index < totalWeight
    }

    /**
     * Returns the weight of a device, clamped to be non-negative.
     * <p>
     * The value is read as a {@link Number} so this works whether the weight is declared as a
     * primitive or a boxed type; a null weight counts as 0.
     */
    private static int weightOf(DeviceContext device) {
        Number weight = device.getWeight();
        return weight == null ? 0 : Math.max(weight.intValue(), 0);
    }

    /**
     * Scheduled entry point. Any exception is caught and logged here, because an exception
     * escaping a {@code scheduleWithFixedDelay} task cancels all later runs.
     */
    private void syncDevice() {
        try {
            doSyncDevice();
        } catch (Exception ex) {
            log.error("sync device failed, will retry in {} minutes", SYNC_INTERVAL_MINUTES, ex);
        }
    }

    /**
     * Reloads the devices that have finished TMK setup and belong to a group.
     * <p>
     * It rebuilds the serviceId -> contexts snapshot: existing contexts are reused for devices
     * that are still present, and only new devices are connected.
     */
    private void doSyncDevice() {
        log.debug(">>>>>>>>>>>>>>> start sync device <<<<<<<<<<<<<<<");
        enableSoftDevice = tmkService.isEnableSoftDevice();
        List<Device> devices = spDeviceMapper.selectList(
                new LambdaQueryWrapper<Device>()
                        .eq(Device::getTmkStatus, DeviceTmkStatus.finished)
                        .gt(Device::getGroupId, 0)
        );
        if (CollectionUtils.isEmpty(devices)) {
            log.info("no device for sync ...");
            replaceDeviceMap(Collections.emptyMap());
            return;
        }
        Map<Long, List<Device>> groupDeviceMap = devices.stream().collect(Collectors.groupingBy(Device::getGroupId));
        List<CryptoServiceDeviceGroup> serviceDeviceGroups = cryptoServiceDeviceGroupMapper
                .selectList(new LambdaQueryWrapper<CryptoServiceDeviceGroup>()
                        .in(CryptoServiceDeviceGroup::getDeviceGroupId, groupDeviceMap.keySet()));
        if (CollectionUtils.isEmpty(serviceDeviceGroups)) {
            replaceDeviceMap(Collections.emptyMap());
            return;
        }
        Map<Long, List<Device>> waitSyncMap = collectServiceDevices(serviceDeviceGroups, groupDeviceMap);
        Map<Long, List<DeviceContext>> current = deviceMap;
        Map<Long, List<DeviceContext>> next = new HashMap<>();
        waitSyncMap.forEach((serviceId, serviceDevices) ->
                next.put(serviceId, mergeContexts(current.get(serviceId), serviceDevices)));
        replaceDeviceMap(next);
    }

    /**
     * Groups the devices by service.
     * <p>
     * A service bound to several device groups gets the union of their devices, rather than a
     * duplicate-key failure.
     */
    private static Map<Long, List<Device>> collectServiceDevices(List<CryptoServiceDeviceGroup> bindings,
                                                                 Map<Long, List<Device>> groupDeviceMap) {
        Map<Long, List<Device>> result = new HashMap<>();
        for (CryptoServiceDeviceGroup binding : bindings) {
            List<Device> groupDevices = groupDeviceMap.get(binding.getDeviceGroupId());
            if (!CollectionUtils.isEmpty(groupDevices)) {
                result.computeIfAbsent(binding.getServiceId(), k -> new ArrayList<>()).addAll(groupDevices);
            }
        }
        return result;
    }

    /**
     * Builds a service's new context list.
     * <p>
     * Old contexts whose serial is still wanted are kept; contexts are created only for new
     * serials, and devices that fail to connect are skipped. Duplicate serials are included once.
     *
     * @param old     the service's current contexts, or {@code null}
     * @param devices the devices the service should now use
     * @return an unmodifiable list of contexts
     */
    private List<DeviceContext> mergeContexts(List<DeviceContext> old, List<Device> devices) {
        Set<String> wantedSerials = devices.stream()
                .map(Device::getDeviceSerial)
                .collect(Collectors.toSet());
        Set<String> handledSerials = new HashSet<>();
        List<DeviceContext> result = new ArrayList<>();
        if (old != null) {
            for (DeviceContext context : old) {
                String serial = context.getDeviceSerial();
                if (wantedSerials.contains(serial) && handledSerials.add(serial)) {
                    result.add(context);
                }
            }
        }
        for (Device device : devices) {
            if (handledSerials.add(device.getDeviceSerial())) {
                DeviceContext context = mapToContext(device);
                if (context != null) {
                    result.add(context);
                }
            }
        }
        return Collections.unmodifiableList(result);
    }

    /**
     * Publishes a new snapshot and then closes the pools of contexts that are no longer
     * referenced by it (compared by identity).
     */
    private void replaceDeviceMap(Map<Long, List<DeviceContext>> next) {
        Map<Long, List<DeviceContext>> previous = deviceMap;
        deviceMap = Collections.unmodifiableMap(next);
        Set<DeviceContext> alive = Collections.newSetFromMap(new IdentityHashMap<DeviceContext, Boolean>());
        next.values().forEach(alive::addAll);
        previous.values().stream()
                .flatMap(List::stream)
                .filter(context -> !alive.contains(context))
                .forEach(this::closePool);
    }

    /** Closes a dropped context's object pool; failures are logged, not propagated. */
    private void closePool(DeviceContext context) {
        GenericObjectPool<TMKContext> pool = context.getPool();
        if (pool == null) {
            return;
        }
        try {
            pool.close();
            log.info("device {} removed, pool closed", context.getDeviceSerial());
        } catch (Exception ex) {
            log.warn("close pool of device {} failed", context.getDeviceSerial(), ex);
        }
    }

    /**
     * Checks a device and wraps it in a {@link DeviceContext} with its own TMK context pool.
     *
     * @param device the device entity
     * @return the context, or {@code null} in these cases:
     * <ul>
     *     <li>the device has no encrypted TMK;</li>
     *     <li>the serial or public key it reports differs from the stored one;</li>
     *     <li>any step fails.</li>
     * </ul>
     */
    private DeviceContext mapToContext(Device device) {
        try {
            Assert.hasText(device.getEncTmk(), "TMK 状态异常");
            DeviceCheckRes checkRes = tmkService.checkDevice(device);
            if (!Objects.equals(checkRes.getDeviceSerial(), device.getDeviceSerial())
                    || !Objects.equals(checkRes.getPubKey(), device.getPubKey())) {
                return null;
            }
            DeviceContext context = new DeviceContext();
            context.setIp(device.getServiceIp());
            context.setPort(device.getServicePort());
            context.setModel(device.getManufacturerModel());
            context.setEncKeyIdx(device.getEncKeyIdx());
            context.setAccessCredentials(device.getAccessCredentials());
            context.setDeviceSerial(device.getDeviceSerial());
            context.setPubKey(device.getPubKey());
            context.setEncTmk(device.getEncTmk());
            context.setWeight(device.getWeight());
            context.setSdfApiAdapter(checkRes.getSdfApiAdapter());
            GenericObjectPoolConfig<TMKContext> config = new GenericObjectPoolConfig<>();
            config.setMaxWait(Duration.ofSeconds(5));
            config.setJmxEnabled(false);
            config.setMinIdle(2);
            config.setTimeBetweenEvictionRuns(Duration.ofMinutes(1));
            config.setTestWhileIdle(true);
            TMKContextFactory tenantTMKContextFactory = new TMKContextFactory(checkRes.getSdfApiAdapter(), context);
            GenericObjectPool<TMKContext> pool = new GenericObjectPool<>(tenantTMKContextFactory, config);
            context.setPool(pool);
            return context;
        } catch (Exception ex) {
            // Log identifying fields only. The entity carries accessCredentials and encTmk, which
            // its toString() may include.
            log.warn("device conn error: serial={}, addr={}:{}",
                    device.getDeviceSerial(), device.getServiceIp(), device.getServicePort(), ex);
            return null;
        }
    }

    /**
     * Starts the periodic device synchronisation: the first run is immediate, and each later run
     * starts {@value #SYNC_INTERVAL_MINUTES} minutes after the previous one ends.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        // A daemon thread, so this executor (which is never shut down) cannot keep the JVM alive
        // after the application context closes.
        Executors.newSingleThreadScheduledExecutor(runnable -> {
                    Thread thread = new Thread(runnable, "chsm-device-sync");
                    thread.setDaemon(true);
                    return thread;
                })
                .scheduleWithFixedDelay(this::syncDevice, 0L, SYNC_INTERVAL_MINUTES, TimeUnit.MINUTES);
    }
}