package com.sunyard.chsm.pool;

import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.sunyard.chsm.auth.AppUser;
import com.sunyard.chsm.constant.SecurityConstant;
import com.sunyard.chsm.enums.DeviceTmkStatus;
import com.sunyard.chsm.mapper.CryptoServiceDeviceGroupMapper;
import com.sunyard.chsm.mapper.SpDeviceMapper;
import com.sunyard.chsm.model.dto.DeviceCheckRes;
import com.sunyard.chsm.model.entity.CryptoServiceDeviceGroup;
import com.sunyard.chsm.model.entity.Device;
import com.sunyard.chsm.sdf.adapter.SdfApiAdapter;
import com.sunyard.chsm.sdf.adapter.SdfApiAdapterFactory;
import com.sunyard.chsm.sdf.model.EccCipher;
import com.sunyard.chsm.service.TmkService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Maintains per-crypto-service device context pools and hands out {@link TMKContext}s.
 *
 * <p>A background task started in {@link #afterPropertiesSet()} rebuilds the
 * service-id → device-context mapping every 5 minutes from the database. Request threads call
 * {@link #chooseOne()} to borrow a context (weighted round robin per application) and give it
 * back with {@link #returnContextToPool(TMKContext)}. When no hardware device can be used and the
 * soft device is enabled, a shared software context is returned instead.
 *
 * @author liulu
 * @since 2024/12/10
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class DeviceManager implements InitializingBean, DisposableBean {

    /** Counter value above which a per-application round-robin counter restarts at 1. */
    private static final int ROUND_RESET_THRESHOLD = Integer.MAX_VALUE - 10000;

    /** Per-application round-robin counters; shared by all request threads, hence concurrent. */
    private static final Map<Long, AtomicInteger> ROUND_MAP = new ConcurrentHashMap<>();

    /**
     * Service id → usable device contexts. The sync thread replaces the whole map with a new
     * unmodifiable one; request threads only read it, so no locking is needed.
     */
    private volatile Map<Long, List<DeviceContext>> deviceMap = Collections.emptyMap();

    /** Written by the sync thread, read by request threads. */
    private volatile boolean enableSoftDevice = false;

    private TMKContext softContext;

    private volatile ScheduledExecutorService syncExecutor;

    private final TmkService tmkService;
    private final SpDeviceMapper spDeviceMapper;
    private final CryptoServiceDeviceGroupMapper cryptoServiceDeviceGroupMapper;

    /**
     * Returns a context obtained from {@link #chooseOne()} to the pool it was borrowed from.
     * {@code null} and contexts without a pool (the shared soft context) are ignored.
     *
     * @param context the context to release
     */
    public void returnContextToPool(TMKContext context) {
        if (context == null) {
            return;
        }
        Optional.ofNullable(context.getPool())
                .ifPresent(it -> {
                    // Detach first so that returning the same context twice is a no-op.
                    context.setPool(null);
                    it.returnObject(context);
                });
    }

    /**
     * Picks a context for the current request's application.
     *
     * <p>Devices of all services listed in the request attribute {@code used_service_ids} are
     * candidates; one is chosen by weighted round robin and a context is borrowed from its pool
     * (waiting at most 2 seconds). If no device is available or borrowing fails, the soft
     * context is used.
     *
     * @return a context; callers must release it with {@link #returnContextToPool(TMKContext)}
     * @throws IllegalArgumentException if there is no app user, no usable service, or no context
     */
    public TMKContext chooseOne() {
        RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
        AppUser user = (AppUser) attributes.getAttribute(SecurityConstant.ATTRIBUTE_APP_USER, RequestAttributes.SCOPE_REQUEST);
        Assert.notNull(user, "登录用户不能为空");
        @SuppressWarnings("unchecked")
        List<Long> serviceIds = (List<Long>) attributes.getAttribute("used_service_ids", RequestAttributes.SCOPE_REQUEST);
        Assert.isTrue(!CollectionUtils.isEmpty(serviceIds), "应用: " + user.getName() + "没有可用服务");

        AtomicInteger counter = ROUND_MAP.computeIfAbsent(user.getAppId(), k -> new AtomicInteger(1));
        // Atomic read-and-advance; restarts at 1 long before the int range is exhausted.
        int totalCalls = counter.getAndUpdate(i -> i > ROUND_RESET_THRESHOLD ? 1 : i + 1);

        // Read the volatile field once so all services resolve against the same snapshot.
        Map<Long, List<DeviceContext>> snapshot = deviceMap;
        List<DeviceContext> contexts = new ArrayList<>();
        for (Long serviceId : serviceIds) {
            Optional.ofNullable(snapshot.get(serviceId))
                    .ifPresent(contexts::addAll);
        }
        DeviceContext device = getNextDevice(contexts, totalCalls);
        TMKContext tmkContext = device == null ? getSoftContext() : borrowFrom(device);
        Assert.notNull(tmkContext, "应用: " + user.getName() + "没有可用的密码设备");
        return tmkContext;
    }

    /** Borrows a context from the device's pool, falling back to the soft context on failure. */
    private TMKContext borrowFrom(DeviceContext device) {
        try {
            GenericObjectPool<TMKContext> pool = device.getPool();
            TMKContext tmkContext = pool.borrowObject(2000);
            tmkContext.setPool(pool);
            return tmkContext;
        } catch (Exception e) {
            log.warn("borrow context from device {} failed, fallback to soft device", device.getDeviceSerial(), e);
            return getSoftContext();
        }
    }

    /**
     * Returns the lazily created, shared software (BC adapter) context.
     *
     * @return the soft context, or {@code null} if the soft device is disabled or no encrypted
     *         soft TMK is available
     */
    public synchronized TMKContext getSoftContext() {
        if (!enableSoftDevice) {
            return null;
        }
        if (Objects.nonNull(softContext)) {
            return softContext;
        }
        byte[] softEncTmk = tmkService.getSoftDeviceEncTmk();
        if (softEncTmk == null || softEncTmk.length == 0) {
            return null;
        }
        SdfApiAdapter bcAdapter = SdfApiAdapterFactory.getBcAdapter();
        String hk = bcAdapter.importKeyWithISKECC("", 1, EccCipher.fromBytes(softEncTmk));
        softContext = new TMKContext();
        softContext.setSdfApiAdapter(bcAdapter);
        softContext.setKeyHandle(hk);
        return softContext;
    }

    /**
     * Selects a device by weighted round robin. The call counter is reduced modulo the total
     * weight, and each device owns a slice of that range as wide as its weight.
     *
     * @param devices    candidate devices, may be {@code null} or empty
     * @param totalCalls monotonically advancing call counter (may be negative)
     * @return the chosen device, or {@code null} if there are no candidates or the total weight
     *         is not positive
     */
    public static DeviceContext getNextDevice(List<DeviceContext> devices, int totalCalls) {
        if (devices == null || devices.isEmpty()) {
            return null;
        }
        int totalWeight = 0;
        for (DeviceContext device : devices) {
            totalWeight += device.getWeight();
        }
        if (totalWeight <= 0) {
            // Nothing to distribute over; the modulo below would divide by zero.
            return null;
        }
        // floorMod keeps the index in [0, totalWeight) even for a negative counter.
        int index = Math.floorMod(totalCalls, totalWeight);
        for (DeviceContext device : devices) {
            if (index < device.getWeight()) {
                return device;
            }
            index -= device.getWeight();
        }
        return null; // only reachable when some individual weight is negative
    }

    /**
     * Scheduled entry point. Exceptions are caught here because an exception escaping a
     * {@code scheduleWithFixedDelay} task cancels every subsequent run.
     */
    private void syncDevice() {
        try {
            doSyncDevice();
        } catch (Exception ex) {
            log.error("sync device failed", ex);
        }
    }

    /**
     * Rebuilds the service → device-context map from the database and publishes it.
     * Contexts of devices that are still bound are reused. Pools of contexts that are no longer
     * referenced are closed after the new map is published.
     */
    private void doSyncDevice() {
        log.debug(">>>>>>>>>>>>>>> start sync device <<<<<<<<<<<<<<<");
        enableSoftDevice = tmkService.isEnableSoftDevice();

        Map<Long, List<Device>> serviceDevices = loadServiceDevices();
        Map<Long, List<DeviceContext>> oldMap = deviceMap;
        Map<Long, List<DeviceContext>> newMap = new HashMap<>();
        for (Map.Entry<Long, List<Device>> entry : serviceDevices.entrySet()) {
            newMap.put(entry.getKey(), buildContexts(entry.getValue(), oldMap.get(entry.getKey())));
        }
        deviceMap = Collections.unmodifiableMap(newMap);
        closeRemovedPools(oldMap, newMap);
    }

    /**
     * Loads devices whose TMK is finished and that belong to a group, keyed by the services bound
     * to those groups. A service bound to several groups gets the devices of all of them.
     */
    private Map<Long, List<Device>> loadServiceDevices() {
        List<Device> devices = spDeviceMapper.selectList(
                new LambdaQueryWrapper<Device>()
                        .eq(Device::getTmkStatus, DeviceTmkStatus.finished)
                        .gt(Device::getGroupId, 0)
        );
        if (CollectionUtils.isEmpty(devices)) {
            log.info("no device for sync ...");
            return Collections.emptyMap();
        }
        Map<?, List<Device>> groupDeviceMap = devices.stream()
                .collect(Collectors.groupingBy(Device::getGroupId));
        List<CryptoServiceDeviceGroup> serviceDeviceGroups = cryptoServiceDeviceGroupMapper
                .selectList(new LambdaQueryWrapper<CryptoServiceDeviceGroup>()
                        .in(CryptoServiceDeviceGroup::getDeviceGroupId, groupDeviceMap.keySet()));
        Map<Long, List<Device>> result = new HashMap<>();
        if (CollectionUtils.isEmpty(serviceDeviceGroups)) {
            return result;
        }
        for (CryptoServiceDeviceGroup binding : serviceDeviceGroups) {
            List<Device> groupDevices = groupDeviceMap.get(binding.getDeviceGroupId());
            if (groupDevices != null) {
                // Merge instead of toMap(): toMap throws on a service bound to several groups.
                result.computeIfAbsent(binding.getServiceId(), k -> new ArrayList<>()).addAll(groupDevices);
            }
        }
        return result;
    }

    /**
     * Builds one service's context list. Existing contexts whose serial is still bound are kept
     * in their previous order, followed by contexts for newly bound devices (created by
     * {@link #mapToContext}, which may return {@code null} for devices that fail the check).
     * Devices are de-duplicated by serial in case the same group is bound more than once.
     */
    private List<DeviceContext> buildContexts(List<Device> devices, List<DeviceContext> old) {
        Map<String, Device> pending = new LinkedHashMap<>();
        for (Device device : devices) {
            pending.putIfAbsent(device.getDeviceSerial(), device);
        }
        List<DeviceContext> contexts = new ArrayList<>();
        if (old != null) {
            for (DeviceContext context : old) {
                if (pending.remove(context.getDeviceSerial()) != null) {
                    contexts.add(context);
                }
            }
        }
        for (Device device : pending.values()) {
            DeviceContext context = mapToContext(device);
            if (context != null) {
                contexts.add(context);
            }
        }
        return Collections.unmodifiableList(contexts);
    }

    /** Closes the pools of contexts present in {@code oldMap} but not (by identity) in {@code newMap}. */
    private static void closeRemovedPools(Map<Long, List<DeviceContext>> oldMap,
                                          Map<Long, List<DeviceContext>> newMap) {
        Set<DeviceContext> retained = Collections.newSetFromMap(new IdentityHashMap<>());
        newMap.values().forEach(retained::addAll);
        oldMap.values().stream()
                .flatMap(List::stream)
                .filter(it -> !retained.contains(it))
                .forEach(DeviceManager::closePool);
    }

    /** Best-effort close of a context's pool; failures are logged and otherwise ignored. */
    private static void closePool(DeviceContext context) {
        GenericObjectPool<?> pool = context.getPool();
        if (pool == null) {
            return;
        }
        try {
            pool.close();
        } catch (RuntimeException e) {
            log.warn("close pool of device {} failed", context.getDeviceSerial(), e);
        }
    }

    /**
     * Verifies a device against the stored serial/public key and builds a pooled context for it.
     *
     * @param device a device whose TMK is expected to be present
     * @return the context, or {@code null} if the check fails or an error occurs
     */
    private DeviceContext mapToContext(Device device) {
        try {
            Assert.hasText(device.getEncTmk(), "TMK 状态异常");
            DeviceCheckRes checkRes = tmkService.checkDevice(device);
            if (!Objects.equals(checkRes.getDeviceSerial(), device.getDeviceSerial())
                    || !Objects.equals(checkRes.getPubKey(), device.getPubKey())) {
                log.warn("device check mismatch, serial: {}, ip: {}", device.getDeviceSerial(), device.getServiceIp());
                return null;
            }
            DeviceContext context = new DeviceContext();
            context.setIp(device.getServiceIp());
            context.setPort(device.getServicePort());
            context.setModel(device.getManufacturerModel());
            context.setEncKeyIdx(device.getEncKeyIdx());
            context.setAccessCredentials(device.getAccessCredentials());
            context.setDeviceSerial(device.getDeviceSerial());
            context.setPubKey(device.getPubKey());
            context.setEncTmk(device.getEncTmk());
            context.setWeight(device.getWeight());
            context.setSdfApiAdapter(checkRes.getSdfApiAdapter());

            GenericObjectPoolConfig<TMKContext> config = new GenericObjectPoolConfig<>();
            config.setMaxWait(Duration.ofSeconds(5));
            config.setJmxEnabled(false);
            config.setMinIdle(2);
            config.setTimeBetweenEvictionRuns(Duration.ofMinutes(1));
            config.setTestWhileIdle(true);
            TMKContextFactory tenantTMKContextFactory = new TMKContextFactory(checkRes.getSdfApiAdapter(), context);
            GenericObjectPool<TMKContext> pool = new GenericObjectPool<>(tenantTMKContextFactory, config);
            context.setPool(pool);
            return context;
        } catch (Exception ex) {
            // Log identifying fields only: the Device entity carries credentials / encrypted TMK.
            log.warn("device conn error, serial: {}, ip: {}", device.getDeviceSerial(), device.getServiceIp(), ex);
            return null;
        }
    }

    /** Starts the device sync task: immediately, then every 5 minutes after each run completes. */
    @Override
    public void afterPropertiesSet() throws Exception {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(r -> {
            Thread thread = new Thread(r, "device-sync");
            // Daemon so the background sync never keeps the JVM alive on its own.
            thread.setDaemon(true);
            return thread;
        });
        syncExecutor = executor;
        executor.scheduleWithFixedDelay(this::syncDevice, 0L, 5L, TimeUnit.MINUTES);
    }

    /** Stops the sync task and closes every device pool. */
    @Override
    public void destroy() {
        ScheduledExecutorService executor = syncExecutor;
        if (executor != null) {
            executor.shutdownNow();
        }
        Map<Long, List<DeviceContext>> current = deviceMap;
        deviceMap = Collections.emptyMap();
        current.values().stream()
                .flatMap(List::stream)
                .forEach(DeviceManager::closePool);
    }
}