字段加解密
定义注解
java
package org.elsfs.cloud.common.mybatis.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.elsfs.cloud.common.mybatis.algorithm.ICipherAlgorithm;
/**
 * Marks an entity field whose value should be encrypted or decrypted by MybatisSensitiveInterceptor.
 * That interceptor only rewrites fields declared as {@code String}.
 *
 * <p>Note: {@code @Inherited} only affects class-level annotations, so it has no effect on this
 * field-level annotation; it is kept for compatibility.
 *
 * @author zeng
 */
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface CipherField {

  /**
   * Algorithm implementation used for this field. It must expose a public no-arg constructor so it
   * can be instantiated reflectively.
   *
   * @return the encryption/decryption provider class
   */
  Class<? extends ICipherAlgorithm> value();
}
加密算法接口
java
package org.elsfs.cloud.common.mybatis.algorithm;
/**
 * Strategy for transforming a single sensitive string value in both directions.
 *
 * @author zeng
 */
public interface ICipherAlgorithm {

  /**
   * Transforms plaintext into its stored (encrypted) form.
   *
   * @param value the plaintext
   * @return the ciphertext
   */
  String encrypt(String value);

  /**
   * Transforms a stored (encrypted) value back into plaintext. Irreversible algorithms should
   * return the given ciphertext unchanged.
   *
   * @param value the ciphertext
   * @return the plaintext, or {@code value} itself when the algorithm cannot be reversed
   */
  String decrypt(String value);
}
数据库字段加解密开关
java
package org.elsfs.cloud.common.mybatis.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Switch placed on a mapper method to turn on field encryption/decryption for that statement.
 *
 * @author zeng
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface EnableCipher {

  /**
   * How the statement parameter is processed before execution.
   *
   * @return the processing mode for the parameter; {@link CipherType#NONE} by default
   */
  CipherType parameter() default CipherType.NONE;

  /**
   * How the statement result is processed after execution.
   *
   * @return the processing mode for the result; {@link CipherType#NONE} by default
   */
  CipherType result() default CipherType.NONE;

  /** Processing mode applied to parameters or results. */
  enum CipherType {
    /** Leave values untouched. */
    NONE,
    /** Encrypt values. */
    ENCRYPT,
    /** Decrypt values. */
    DECRYPT
  }
}
实现
java
package org.elsfs.cloud.common.mybatis.interceptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.cache.impl.PerpetualCache;
import org.apache.ibatis.executor.CachingExecutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.elsfs.cloud.common.mybatis.algorithm.ICipherAlgorithm;
import org.elsfs.cloud.common.mybatis.annotation.CipherField;
import org.elsfs.cloud.common.mybatis.annotation.EnableCipher;
import org.springframework.util.ReflectionUtils;
/**
 * MyBatis executor interceptor that encrypts/decrypts {@link CipherField}-annotated {@code String}
 * fields of statement parameters and results. It is driven by an {@link EnableCipher} annotation on
 * the mapper method.
 *
 * <p>Parameter objects are modified in place before the statement runs. They are restored to their
 * original values afterwards, so callers keep the objects they passed in. Result objects are
 * modified in place and not restored.
 *
 * @author zeng
 */
@Intercepts({
  @Signature(
      type = Executor.class,
      method = "update",
      args = {MappedStatement.class, Object.class}),
  @Signature(
      type = Executor.class,
      method = "query",
      args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
  @Signature(
      type = Executor.class,
      method = "query",
      args = {
        MappedStatement.class,
        Object.class,
        RowBounds.class,
        ResultHandler.class,
        CacheKey.class,
        BoundSql.class
      }),
})
@Slf4j
public class MybatisSensitiveInterceptor implements Interceptor {

  private static final String SELECT_KEY = "!selectKey";
  private static final int EXECUTOR_PARAMETER_COUNT_4 = 4;
  private static final int MAPPED_STATEMENT_INDEX = 0;
  private static final int PARAMETER_INDEX = 1;
  private static final int ROW_BOUNDS_INDEX = 2;
  private static final int CACHE_KEY_INDEX = 4;

  /**
   * Per-thread record of the parameter objects rewritten by {@link #handleParameter}, together with
   * each field's before/after value.
   *
   * <p>Keyed by object identity: the entity's equals/hashCode may depend on the very fields being
   * rewritten, and distinct-but-equal instances must be tracked separately.
   */
  private final ThreadLocal<Map<Object, Map<Field, CipherValue>>> cipherFieldMapLocal =
      ThreadLocal.withInitial(IdentityHashMap::new);

  /** One shared instance per algorithm class, created lazily. */
  private final Map<Class<? extends ICipherAlgorithm>, ICipherAlgorithm> algorithmMap =
      new ConcurrentHashMap<>(2);

  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    Object[] args = invocation.getArgs();
    MappedStatement statement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
    Object parameter = args[PARAMETER_INDEX];
    // Special handling for <selectKey> statements
    handleSelectKey(statement, parameter);
    EnableCipher enableCipher = getEnableCipher(statement);
    if (enableCipher == null) {
      return invocation.proceed();
    }
    // Encrypt/decrypt every annotated parameter field with its configured algorithm
    handleParameter(parameter, enableCipher.parameter());
    // TODO cache handling: decide whether the first-level cache was hit
    // boolean hitCache = hitCache(invocation, parameter);
    Object proceed;
    try {
      // Exceptions propagate unchanged; they are no longer wrapped in RuntimeException,
      // which hid the original exception type from callers.
      proceed = invocation.proceed();
    } finally {
      // Restore the caller's parameter objects to their pre-processing values
      revertParameter(parameter);
    }
    // TODO cache handling
    // return hitCache ? proceed : handleResult(proceed, enableCipher.result());
    return handleResult(proceed, enableCipher.result());
  }

  @Override
  public Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  @Override
  public void setProperties(Properties properties) {}

  /**
   * Finds the {@link EnableCipher} annotation on the mapper method backing a statement.
   *
   * <p>The statement id is split at its last '.' into the mapper class name and method name. The
   * first public method with that name carrying the annotation wins.
   *
   * @param statement the mapped statement
   * @return the annotation, or {@code null} when absent or the class cannot be loaded
   */
  private EnableCipher getEnableCipher(MappedStatement statement) {
    String namespace = statement.getId();
    int lastDot = namespace.lastIndexOf('.');
    if (lastDot < 0) {
      return null;
    }
    String className = namespace.substring(0, lastDot);
    String methodName = namespace.substring(lastDot + 1);
    try {
      for (Method method : Class.forName(className).getMethods()) {
        if (method.getName().equals(methodName)
            && method.isAnnotationPresent(EnableCipher.class)) {
          return method.getAnnotation(EnableCipher.class);
        }
      }
    } catch (ClassNotFoundException e) {
      LOGGER.error("get @EnableCipher from {} error!", namespace);
    }
    return null;
  }

  /**
   * Restores cipher-processed parameters when a {@code <selectKey>} statement is executed.
   *
   * <p>NOTE(review): for a BEFORE selectKey this appears to run while the owning INSERT is still
   * being prepared. That would restore plaintext before the INSERT binds its parameters — confirm
   * this is intended.
   *
   * @param statement the statement being executed
   * @param parameter its parameter object
   */
  private void handleSelectKey(MappedStatement statement, Object parameter) {
    if (statement.getSqlCommandType() == SqlCommandType.SELECT
        && statement.getId().endsWith(SELECT_KEY)) {
      revertParameter(parameter);
    }
  }

  /**
   * Encrypts or decrypts the annotated fields of the statement parameter in place. It records the
   * original values in the thread-local map so {@link #revertParameter} can undo the change.
   *
   * @param parameter the statement parameter (single object, collection, or MyBatis param map)
   * @param cipherType how to process the fields
   */
  private void handleParameter(Object parameter, EnableCipher.CipherType cipherType) {
    if (cipherType == EnableCipher.CipherType.NONE) {
      return;
    }
    Map<Object, Map<Field, CipherValue>> cipherMap = new IdentityHashMap<>();
    if (parameter instanceof Map<?, ?> parameterMap) {
      // A param map may expose the same object under several keys; process each instance once
      filterRepeatValueMap(parameterMap)
          .values()
          .forEach(value -> handleCipher(value, cipherType, cipherMap));
    } else {
      handleCipher(parameter, cipherType, cipherMap);
    }
    // Keep the processed values on this thread so they can be restored after execution
    if (!cipherMap.isEmpty()) {
      cipherFieldMapLocal.get().putAll(cipherMap);
    }
  }

  /**
   * Encrypts or decrypts the annotated fields of the execution result in place.
   *
   * @param result the result object (a single object or a collection)
   * @param cipherType how to process the fields
   * @return the same {@code result} instance
   */
  private Object handleResult(Object result, EnableCipher.CipherType cipherType) {
    if (cipherType != EnableCipher.CipherType.NONE) {
      handleCipher(result, cipherType, new IdentityHashMap<>());
    }
    return result;
  }

  /**
   * Restores the original values of the parameter objects recorded by {@link #handleParameter},
   * then discards this thread's record.
   *
   * @param parameter the statement parameter that was processed
   */
  private void revertParameter(Object parameter) {
    final Map<Object, Map<Field, CipherValue>> cipherFieldMap = cipherFieldMapLocal.get();
    if (cipherFieldMap.isEmpty()) {
      return;
    }
    try {
      if (parameter instanceof Map<?, ?> map) {
        filterRepeatValueMap(map)
            .values()
            .forEach(
                value -> {
                  if (value instanceof Collection<?> elements) {
                    elements.forEach(element -> restoreFields(element, cipherFieldMap));
                  } else {
                    restoreFields(value, cipherFieldMap);
                  }
                });
      } else {
        restoreFields(parameter, cipherFieldMap);
      }
    } finally {
      // remove() rather than clear() so pooled threads do not retain the map
      cipherFieldMapLocal.remove();
    }
  }

  /**
   * Writes the recorded "before" values back into {@code target}, if it was processed.
   *
   * @param target a parameter object (may be {@code null})
   * @param cipherFieldMap identity map of processed objects to their field changes
   */
  private static void restoreFields(
      Object target, Map<Object, Map<Field, CipherValue>> cipherFieldMap) {
    if (target == null) {
      return;
    }
    Map<Field, CipherValue> valueMap = cipherFieldMap.get(target);
    if (valueMap != null) {
      valueMap.forEach(
          (field, cipher) -> ReflectionUtils.setField(field, target, cipher.getBefore()));
    }
  }

  /**
   * Checks whether a query would be answered from the executor's first-level (local) cache.
   * Currently only referenced from commented-out code in {@link #intercept}.
   *
   * @param invocation the intercepted executor call
   * @param parameter the (already processed) parameter object
   * @return {@code true} if the local cache holds an entry for this query
   */
  private boolean hitCache(Invocation invocation, Object parameter) {
    Object[] args = invocation.getArgs();
    MappedStatement mappedStatement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
    // Non-query statements never hit the cache
    if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
      return false;
    }
    Executor executor = (Executor) invocation.getTarget();
    CacheKey cacheKey;
    if (args.length == EXECUTOR_PARAMETER_COUNT_4) {
      RowBounds rowBounds = (RowBounds) args[ROW_BOUNDS_INDEX];
      BoundSql boundSql = mappedStatement.getBoundSql(parameter);
      cacheKey = executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
    } else {
      cacheKey = (CacheKey) args[CACHE_KEY_INDEX];
    }
    // CachingExecutor only adds the second-level cache; the local cache is on its delegate
    Object baseExecutor = executor;
    if (executor instanceof CachingExecutor) {
      Field delegateField = ReflectionUtils.findField(CachingExecutor.class, "delegate");
      if (delegateField == null) {
        return false;
      }
      ReflectionUtils.makeAccessible(delegateField);
      baseExecutor = ReflectionUtils.getField(delegateField, executor);
    }
    if (baseExecutor == null) {
      return false;
    }
    // "localCache" is declared in the base executor's class hierarchy, not on CachingExecutor
    Field localCacheField = ReflectionUtils.findField(baseExecutor.getClass(), "localCache");
    if (localCacheField == null) {
      return false;
    }
    ReflectionUtils.makeAccessible(localCacheField);
    Object localCache = ReflectionUtils.getField(localCacheField, baseExecutor);
    return localCache instanceof PerpetualCache cache && cache.getObject(cacheKey) != null;
  }

  /**
   * Returns the shared algorithm instance for the given class, creating it on first use.
   *
   * @param value algorithm implementation class (needs a no-arg constructor)
   * @return the algorithm instance
   * @throws RuntimeException if the class cannot be instantiated
   */
  private ICipherAlgorithm getCipherAlgorithm(Class<? extends ICipherAlgorithm> value) {
    // Look up and store under the same key, so the cache actually hits
    return algorithmMap.computeIfAbsent(value, MybatisSensitiveInterceptor::newAlgorithm);
  }

  /** Reflectively instantiates an algorithm through its no-arg constructor. */
  private static ICipherAlgorithm newAlgorithm(Class<? extends ICipherAlgorithm> type) {
    try {
      return type.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException("init ICipherAlgorithm error", e);
    }
  }

  /**
   * Processes a value that is either a collection (each element handled) or a single object.
   *
   * @param value object or collection of objects to process (may be {@code null})
   * @param cipherType how to process the fields
   * @param processed identity map receiving the field changes of each rewritten object
   */
  private void handleCipher(
      Object value,
      EnableCipher.CipherType cipherType,
      Map<Object, Map<Field, CipherValue>> processed) {
    if (value instanceof Collection<?> elements) {
      elements.forEach(element -> cipherObject(element, cipherType, processed));
    } else {
      cipherObject(value, cipherType, processed);
    }
  }

  /**
   * Encrypts or decrypts every non-null {@code String} field annotated with {@link CipherField}.
   *
   * @param object the object to rewrite (ignored when {@code null})
   * @param cipherType how to process the fields
   * @param processed identity map receiving this object's field changes
   */
  private void cipherObject(
      Object object,
      EnableCipher.CipherType cipherType,
      Map<Object, Map<Field, CipherValue>> processed) {
    // Skip nulls, and instances already rewritten in this pass so they are never processed twice
    if (object == null || processed.containsKey(object)) {
      return;
    }
    Map<Field, CipherValue> valueMap = new HashMap<>();
    for (Field field : getFields(object)) {
      if (!field.isAnnotationPresent(CipherField.class) || field.getType() != String.class) {
        continue;
      }
      ICipherAlgorithm algorithm =
          getCipherAlgorithm(field.getAnnotation(CipherField.class).value());
      String value = (String) getField(field, object);
      if (value == null) {
        continue;
      }
      String algorithmValue =
          switch (cipherType) {
            case ENCRYPT -> algorithm.encrypt(value);
            case DECRYPT -> algorithm.decrypt(value);
            case NONE -> null;
          };
      if (algorithmValue != null) {
        ReflectionUtils.setField(field, object, algorithmValue);
        valueMap.put(field, new CipherValue(value, algorithmValue));
      }
    }
    if (!valueMap.isEmpty()) {
      processed.put(object, valueMap);
    }
  }

  /**
   * Returns the non-null values of a parameter map with duplicate instances removed.
   *
   * <p>Duplicates are detected by identity, not {@code hashCode()`. With hashCode, a collision or
   * two equal-but-distinct objects caused one of them to be silently skipped (left unencrypted).
   *
   * @param parameter the MyBatis parameter map
   * @return a new map holding one entry per distinct non-null value instance
   */
  private Map<Object, Object> filterRepeatValueMap(Map<?, ?> parameter) {
    Set<Object> seen = Collections.newSetFromMap(new IdentityHashMap<>());
    Map<Object, Object> distinct = new HashMap<>();
    parameter.forEach(
        (key, value) -> {
          if (value != null && seen.add(value)) {
            distinct.put(key, value);
          }
        });
    return distinct;
  }

  /** Collects the declared fields of the object's class and all its superclasses. */
  private List<Field> getFields(Object obj) {
    List<Field> fieldList = new ArrayList<>();
    for (Class<?> type = obj.getClass(); type != null; type = type.getSuperclass()) {
      fieldList.addAll(Arrays.asList(type.getDeclaredFields()));
    }
    return fieldList;
  }

  /** Makes the field accessible (needed later by setField too) and reads its value. */
  private Object getField(Field field, Object obj) {
    ReflectionUtils.makeAccessible(field);
    return ReflectionUtils.getField(field, obj);
  }

  /** A field's value before and after cipher processing. */
  @Getter
  static class CipherValue {
    /** Value before processing. */
    private final String before;

    /** Value after processing. */
    private final String after;

    public CipherValue(String before, String after) {
      this.before = before;
      this.after = after;
    }
  }
}
配置 bean
java
/**
 * Exposes a new {@link MybatisSensitiveInterceptor} as a Spring bean.
 *
 * <p>NOTE(review): registration with MyBatis presumably relies on the MyBatis/MyBatis-Plus Spring
 * Boot auto-configuration picking up Interceptor beans — confirm for this project's setup.
 *
 * @return a fresh interceptor instance that applies field encryption/decryption
 */
@Bean
public MybatisSensitiveInterceptor mybatisSensitiveInterceptor() {
  final MybatisSensitiveInterceptor interceptor = new MybatisSensitiveInterceptor();
  return interceptor;
}
使用
实现加密接口
java
/**
 * Demonstration "cipher" that appends, and strips again, the fixed suffix {@code "123"}.
 * It is not real encryption; use only in examples and tests.
 */
public class NameCipherAlgorithm implements ICipherAlgorithm {

  /** Marker appended by {@link #encrypt}. */
  private static final String SUFFIX = "123";

  /**
   * Appends the suffix.
   *
   * @param value plaintext
   * @return plaintext followed by {@code "123"}
   */
  @Override
  public String encrypt(String value) {
    return value.concat(SUFFIX);
  }

  /**
   * Strips the suffix when present. Values without it are returned unchanged, instead of throwing
   * (shorter than the suffix) or losing their last three characters (never encrypted).
   *
   * @param value ciphertext
   * @return the plaintext, or {@code value} itself if it does not carry the suffix
   */
  @Override
  public String decrypt(String value) {
    return value.endsWith(SUFFIX) ? value.substring(0, value.length() - SUFFIX.length()) : value;
  }
}
定义 mapper
java
/** Mapper for {@code User}, demonstrating {@link EnableCipher} on individual methods. */
@Mapper
public interface UserMapper extends BaseMapper<User> {

  /**
   * Inserts a user. The entity's {@code @CipherField} fields are encrypted before execution.
   *
   * @param entity the entity to insert
   * @return number of affected rows
   */
  @EnableCipher(parameter = EnableCipher.CipherType.ENCRYPT)
  @Override
  int insert(User entity);

  /**
   * Loads a user by id without any cipher processing, so stored (encrypted) values come back as-is.
   *
   * @param id primary key
   * @return the stored user
   */
  @Override
  User selectById(Serializable id);

  /**
   * Loads a single user and decrypts its {@code @CipherField} fields in the result.
   *
   * @param queryWrapper query conditions
   * @return the user with decrypted fields
   */
  @Override
  @EnableCipher(result = EnableCipher.CipherType.DECRYPT)
  User selectOne(@Param("ew") Wrapper<User> queryWrapper);
}
java
/** Integration test for field encryption through {@code MybatisSensitiveInterceptor}. */
@SpringBootTest
@MapperScan("org.elsfs.cloud.common.mybatis.interceptor")
@Import(MybatisPlusConfiguration.class)
class MybatisSensitiveInterceptorTest {
  @Autowired private UserMapper userMapper;

  /** Inserting encrypts the stored value but leaves the caller's entity in plaintext. */
  @Test
  public void testEncrypt() {
    var name = "Elsfs-Cloud";
    User entity = User.builder().id("100").name(name).build();
    userMapper.insert(entity);
    // The interceptor restores the parameter object after execution
    Assertions.assertEquals(name, entity.getName());
    // selectById has no @EnableCipher, so the stored (encrypted) value is returned as-is.
    // JUnit's assertEquals takes (expected, actual).
    User user = userMapper.selectById("100");
    Assertions.assertEquals(new NameCipherAlgorithm().encrypt(name), user.getName());
  }
}