1 Star 3 Fork 0

it_life/sql注入拦截器

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
SqlInjectInterceptor.java 6.79 KB
一键复制 编辑 原始数据 按行查看 历史
it_life 提交于 2023-07-26 13:30 . update SqlInjectInterceptor.java.
/**
*
*/
package com.java.portal.core.plugin;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.xmltags.ChooseSqlNode;
import org.apache.ibatis.scripting.xmltags.DynamicSqlSource;
import org.apache.ibatis.scripting.xmltags.SqlNode;
import org.apache.ibatis.scripting.xmltags.StaticTextSqlNode;
import org.apache.ibatis.scripting.xmltags.TextSqlNode;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
/**
 * MyBatis plugin that guards against SQL injection through {@code ${}} text
 * substitutions.
 *
 * <p>
 * For every intercepted {@code Executor.query(MappedStatement, Object, RowBounds,
 * ResultHandler)} call whose statement uses a {@link DynamicSqlSource}, the plugin
 * collects the names referenced by {@code ${...}} placeholders in the statement's
 * SQL node tree, reads the matching values from the parameter object and replaces
 * configured SQL keywords inside those values with {@code N/A}. The parameter
 * object passed by the caller is modified in place.
 *
 * <p>
 * Only the {@code query} signature is registered, so by default only queries are
 * filtered.
 *
 * @author yove
 */
@Component
@Intercepts({ @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class,
        RowBounds.class, ResultHandler.class }) })
public class SqlInjectInterceptor implements Interceptor {
    private static final Logger LOG = LoggerFactory.getLogger(SqlInjectInterceptor.class);

    /** Placeholder keys per mapped-statement id, so each SQL node tree is parsed only once. */
    private static final Map<String, Set<String>> PARAMETER_KEY_CACHE = new ConcurrentHashMap<String, Set<String>>();

    /** Reflected private fields of MyBatis SQL nodes, keyed by field name. */
    private static final Map<String, Field> FIELD_CACHE = new ConcurrentHashMap<String, Field>();

    /** Splits a value into whitespace-delimited tokens. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Text that replaces every filtered keyword. */
    private static final String REPLACEMENT = "N/A";

    /**
     * Keywords that are filtered out of {@code ${}} values. Entries without any
     * letter or digit are matched as substrings; all other entries are matched as
     * whole whitespace-delimited tokens, ignoring case.
     */
    public static String[] KEYWORDS = ";|'|\"|*|%|#|--|//|+|=|or|and|like|select|insert|update|delete|alert|drop|truncate|declare|exec|execute|create|xp_|sp_|0x"
            .split("\\|");

    /**
     * Overrides the filtered keyword list from the {@code sql.keywords} property.
     * An empty or missing property keeps the built-in defaults.
     *
     * @param keywords keywords separated by {@code |}
     */
    @Value("${sql.keywords:}")
    public void setKeywords(String keywords) {
        if (StringUtils.isNotEmpty(keywords)) {
            SqlInjectInterceptor.KEYWORDS = keywords.split("\\|");
        }
    }

    /**
     * Wraps {@link Executor} targets with this interceptor and returns every other
     * target untouched.
     */
    @Override
    public Object plugin(Object arg0) {
        if (arg0 instanceof Executor) {
            return Plugin.wrap(arg0, this);
        }
        return arg0;
    }

    /** No plugin properties are supported; configuration goes through {@link #setKeywords(String)}. */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Filters the values bound to {@code ${}} placeholders before the query is
     * executed.
     *
     * @param invocation the intercepted {@code Executor.query} call; argument 0 is
     *                   the mapped statement and argument 1 the parameter object
     * @return the result of the original query
     * @throws Throwable whatever the wrapped query or the placeholder analysis throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        Object parameterObject = args[1];
        if (null != parameterObject) {
            // MetaObject gives access to the private rootSqlNode of DynamicSqlSource.
            MetaObject mappedStatement = SystemMetaObject.forObject(args[0]);
            if (mappedStatement.getValue("sqlSource") instanceof DynamicSqlSource) {
                String statementId = (String) mappedStatement.getValue("id");
                Set<String> parameterKeys = PARAMETER_KEY_CACHE.get(statementId);
                if (null == parameterKeys) {
                    SqlNode rootSqlNode = (SqlNode) mappedStatement.getValue("sqlSource.rootSqlNode");
                    parameterKeys = parseSqlNode(rootSqlNode);
                    PARAMETER_KEY_CACHE.put(statementId, parameterKeys);
                }
                if (!parameterKeys.isEmpty()) {
                    sanitizeParameters(parameterObject, parameterKeys);
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * Runs {@link #process(Object)} over every placeholder value found on the
     * parameter object and writes the result back when it changed.
     */
    private void sanitizeParameters(Object parameterObject, Set<String> parameterKeys) {
        MetaObject parameterMo = SystemMetaObject.forObject(parameterObject);
        for (String parameterKey : parameterKeys) {
            // NOTE(review): for Map parameters the key is read even when hasGetter is
            // false; confirm this is safe for map types that reject unknown keys.
            if (parameterMo.hasGetter(parameterKey) || parameterObject instanceof Map) {
                Object value = parameterMo.getValue(parameterKey);
                if (null != value && !"".equals(value)) {
                    Object sanitized = process(value);
                    // Skip the write when nothing was filtered so that read-only
                    // parameter objects with harmless values are left alone.
                    if (!value.equals(sanitized)) {
                        parameterMo.setValue(parameterKey, sanitized);
                    }
                }
            }
        }
    }

    /**
     * Replaces the configured SQL keywords inside a {@code ${}} parameter value with
     * {@code N/A}.
     *
     * <p>
     * Keywords made only of symbols are replaced wherever they occur. Keywords
     * containing a letter or digit are only acted on when the value contains
     * whitespace and one of its whitespace-delimited tokens equals the keyword
     * ignoring case; in that case every occurrence of the keyword is replaced,
     * regardless of case.
     *
     * @param value the parameter value; may be {@code null}
     * @return the original object when nothing was filtered, otherwise the filtered
     *         string; {@code null} when {@code value} is {@code null}
     */
    public Object process(Object value) {
        if (null == value) {
            return null;
        }
        // Work on a snapshot so a concurrent setKeywords call cannot change the list mid-loop.
        String[] keywords = KEYWORDS;
        for (String keyword : keywords) {
            if (null == keyword || keyword.isEmpty()) {
                // An empty keyword would match everywhere and insert N/A between all characters.
                continue;
            }
            String text = value.toString();
            if (isWordKeyword(keyword)) {
                if (containsToken(text, keyword)) {
                    // LITERAL keeps regex metacharacters in custom keywords inert; the
                    // case flags make the replacement agree with the case-insensitive detection.
                    value = Pattern
                            .compile(keyword, Pattern.LITERAL | Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE)
                            .matcher(text).replaceAll(REPLACEMENT);
                    LOG.warn("sql脚本中特殊字符【{}】已被过滤", keyword);
                }
            } else if (text.contains(keyword)) {
                value = text.replace(keyword, REPLACEMENT);
                LOG.warn("sql脚本中特殊字符【{}】已被过滤", keyword);
            }
        }
        return value;
    }

    /** Tells whether a keyword contains a letter or digit and must therefore match whole tokens only. */
    private static boolean isWordKeyword(String keyword) {
        for (int i = 0; i < keyword.length(); i++) {
            if (Character.isLetterOrDigit(keyword.charAt(i))) {
                return true;
            }
        }
        return false;
    }

    /** Tells whether a text containing whitespace has a token equal to the keyword, ignoring case. */
    private static boolean containsToken(String text, String keyword) {
        if (!WHITESPACE.matcher(text).find()) {
            return false;
        }
        for (String token : WHITESPACE.split(text)) {
            if (StringUtils.equalsIgnoreCase(token, keyword)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the names used inside {@code ${...}} placeholders of a SQL node and
     * all of its children.
     *
     * @param sqlNode the node to inspect
     * @return the trimmed, non-empty placeholder names; never {@code null}
     * @throws Exception if a node field cannot be read reflectively
     */
    private Set<String> parseSqlNode(SqlNode sqlNode) throws Exception {
        Set<String> parameterKeys = new HashSet<String>();
        if (sqlNode instanceof TextSqlNode) {
            String text = (String) cachedField(sqlNode.getClass(), "text").get(sqlNode);
            // substringsBetween returns null when the text holds no complete ${...} pair.
            String[] expressions = StringUtils.substringsBetween(text, "${", "}");
            if (null != expressions) {
                for (String expression : expressions) {
                    String key = expression.trim();
                    if (!key.isEmpty()) {
                        parameterKeys.add(key);
                    }
                }
            }
            return parameterKeys;
        }
        if (sqlNode instanceof ChooseSqlNode) {
            collectChildren(cachedField(sqlNode.getClass(), "defaultSqlNode").get(sqlNode), parameterKeys);
            collectChildren(cachedField(sqlNode.getClass(), "ifSqlNodes").get(sqlNode), parameterKeys);
        } else if (!(sqlNode instanceof StaticTextSqlNode)) {
            // Not cached: the owning node type is unknown here, so the field is looked up per class.
            Field contentsField = ReflectionUtils.findField(sqlNode.getClass(), "contents");
            if (null != contentsField) {
                contentsField.setAccessible(true);
                collectChildren(contentsField.get(sqlNode), parameterKeys);
            }
        }
        return parameterKeys;
    }

    /** Parses a child that is either a single SQL node or a collection of SQL nodes; ignores {@code null}. */
    private void collectChildren(Object contents, Set<String> parameterKeys) throws Exception {
        if (contents instanceof Collection) {
            for (Object child : (Collection<?>) contents) {
                parameterKeys.addAll(parseSqlNode((SqlNode) child));
            }
        } else if (null != contents) {
            parameterKeys.addAll(parseSqlNode((SqlNode) contents));
        }
    }

    /**
     * Returns the accessible field with the given name, looking it up on the given
     * class on first use and caching it by name afterwards.
     *
     * @throws IllegalStateException if the class has no such field
     */
    private static Field cachedField(Class<?> type, String name) {
        Field field = FIELD_CACHE.get(name);
        if (null == field) {
            field = ReflectionUtils.findField(type, name);
            if (null == field) {
                throw new IllegalStateException("Field '" + name + "' not found on " + type.getName());
            }
            field.setAccessible(true);
            FIELD_CACHE.put(name, field);
        }
        return field;
    }
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Java
1
https://gitee.com/yove/sqlInject.git
git@gitee.com:yove/sqlInject.git
yove
sqlInject
sql注入拦截器
master

搜索帮助