/**
 * MyBatis plugin that paginates any mapped statement whose parameter object is a {@code Page}.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection)}. When the bound parameter object
 * is a {@code Page}, it first runs a {@code select count(*)} query to fill in the page's total
 * record count, then rewrites the statement's SQL into a dialect-specific paged query. Only the
 * {@code mysql} and {@code oracle} dialects are paged; for any other {@code databaseType} the
 * SQL is left untouched.
 *
 * <p>NOTE(review): {@link #intercept(Invocation)} casts the target to
 * {@code RoutingStatementHandler}; if another plugin has already wrapped the handler in a proxy
 * that cast will fail. Confirm the plugin ordering in the MyBatis configuration.
 */
@Intercepts({
    @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class})})
public class PageInterceptor implements Interceptor {

    private static final Logger log = Logger.getLogger(PageInterceptor.class);

    /** Database dialect name (read from the "databaseType" plugin property); selects the paging syntax. */
    private String databaseType;

    /**
     * Runs before the intercepted {@code prepare} call: computes the total record count and
     * swaps in the paged SQL when the statement's parameter object is a {@code Page}.
     *
     * @param invocation the intercepted call; its first argument is the JDBC connection
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    public Object intercept(Invocation invocation) throws Throwable {
        RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
        // RoutingStatementHandler keeps the real handler in its private "delegate" field.
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        Object parameter = boundSql.getParameterObject();
        if (parameter instanceof Page<?>) {
            Page<?> page = (Page<?>) parameter;
            MappedStatement mappedStatement =
                    (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
            Connection connection = (Connection) invocation.getArgs()[0];
            String sql = boundSql.getSql();
            setTotalRecord(page, mappedStatement, connection);
            // BoundSql has no setter for its SQL text, so overwrite the private field directly.
            ReflectUtil.setFieldValue(boundSql, "sql", getPageSql(page, sql));
        }
        return invocation.proceed();
    }

    /**
     * Wraps the target object with this interceptor.
     *
     * @param target the object MyBatis is about to use
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Receives the properties configured where the plugin is registered.
     *
     * @param properties plugin properties; "databaseType" selects the paging dialect
     */
    public void setProperties(Properties properties) {
        this.databaseType = properties.getProperty("databaseType");
    }

    /**
     * Builds the paged query for the configured dialect.
     *
     * <p>NOTE(review): for dialects other than MySQL and Oracle the original SQL is returned
     * as-is, i.e. without the {@code page.getParams()} conditions that the count query applies.
     *
     * @param page paging request
     * @param sql  original SQL statement
     * @return the paged SQL for MySQL or Oracle, otherwise {@code sql} unchanged
     */
    private String getPageSql(Page<?> page, String sql) {
        if ("mysql".equalsIgnoreCase(databaseType)) {
            return getMysqlPageSql(page, new StringBuilder(sql));
        }
        if ("oracle".equalsIgnoreCase(databaseType)) {
            return getOraclePageSql(page, new StringBuilder(sql));
        }
        return sql;
    }

    /**
     * Builds the MySQL paged query: original SQL, the page's extra conditions, then
     * {@code limit offset,size}.
     *
     * @param page       paging request
     * @param sqlBuilder builder holding the original SQL; modified in place
     * @return the MySQL paged SQL
     */
    private String getMysqlPageSql(Page<?> page, StringBuilder sqlBuilder) {
        appendParamConditions(page, sqlBuilder);
        // MySQL row positions are zero-based. Computed in long so large page numbers cannot overflow.
        long offset = (long) (page.getPageNo() - 1) * page.getPageSize();
        sqlBuilder.append(" limit ").append(offset).append(",").append(page.getPageSize());
        return sqlBuilder.toString();
    }

    /**
     * Builds the Oracle paged query by wrapping the statement in two {@code rownum} filters.
     *
     * @param page       paging request
     * @param sqlBuilder builder holding the original SQL; modified in place
     * @return the Oracle paged SQL
     */
    private String getOraclePageSql(Page<?> page, StringBuilder sqlBuilder) {
        // Apply the same extra conditions as the MySQL path and the count query, so the rows
        // returned agree with the total record count.
        appendParamConditions(page, sqlBuilder);
        // Oracle rownum is one-based. Computed in long so large page numbers cannot overflow.
        long offset = (long) (page.getPageNo() - 1) * page.getPageSize() + 1;
        sqlBuilder.insert(0, "select u.*, rownum r from (")
                .append(") u where rownum < ")
                .append(offset + page.getPageSize());
        sqlBuilder.insert(0, "select * from (").append(") where r >= ").append(offset);
        // The result looks roughly like:
        // select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
        return sqlBuilder.toString();
    }

    /**
     * Runs the count query for the statement and stores the result on the page.
     *
     * <p>The count is best-effort: an {@link SQLException} is logged and the page's total is
     * left unchanged, so the main query still runs.
     *
     * @param page            parameter object of the mapped statement; receives the total
     * @param mappedStatement the mapped statement being executed
     * @param connection      the current database connection
     */
    private void setTotalRecord(Page<?> page,
            MappedStatement mappedStatement, Connection connection) {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        String countSql = getCountSql(page, boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        // Reuse the original parameter mappings so the count query binds the same '?' values.
        BoundSql countBoundSql =
                new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
        ParameterHandler parameterHandler =
                new DefaultParameterHandler(mappedStatement, page, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                page.setTotalRecord(rs.getInt(1));
            }
        } catch (SQLException e) {
            log.error("Failed to query total record count with SQL: " + countSql, e);
        } finally {
            // Close each resource independently so a failure on one cannot leak the other.
            closeQuietly(rs);
            closeQuietly(pstmt);
        }
    }

    /** Closes the result set if non-null, logging (not propagating) any failure. */
    private static void closeQuietly(ResultSet rs) {
        if (rs == null) {
            return;
        }
        try {
            rs.close();
        } catch (SQLException e) {
            log.error("Failed to close ResultSet", e);
        }
    }

    /** Closes the statement if non-null, logging (not propagating) any failure. */
    private static void closeQuietly(PreparedStatement pstmt) {
        if (pstmt == null) {
            return;
        }
        try {
            pstmt.close();
        } catch (SQLException e) {
            log.error("Failed to close PreparedStatement", e);
        }
    }

    /**
     * Appends the page's extra conditions to the SQL as {@code where k = v and k2 = v2 ...}.
     * Nothing is appended when {@code page.getParams()} is null or empty.
     *
     * <p>NOTE(review): keys and values are concatenated into the SQL verbatim (no quoting, no
     * bind parameters), so they must never carry untrusted input. A {@code where} is always
     * emitted, so this presumably assumes the mapped statement has no WHERE / ORDER BY clause
     * of its own -- verify against the mappers that use {@code Page}.
     *
     * @param page       paging request carrying the extra conditions
     * @param sqlBuilder SQL to extend; modified in place
     */
    private void appendParamConditions(Page<?> page, StringBuilder sqlBuilder) {
        Map<String, Object> params = page.getParams();
        if (params == null) {
            return;
        }
        String keyword = " where ";
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            sqlBuilder.append(keyword).append(entry.getKey()).append(" = ").append(entry.getValue());
            keyword = " and ";
        }
    }

    /**
     * Derives the count query from the original SQL by replacing everything before the first
     * {@code FROM} keyword with {@code select count(*)}, then appending the page's extra
     * conditions.
     *
     * <p>The keyword is matched case-insensitively and only as a whole word, so
     * {@code SELECT ... FROM ...} works and identifiers such as {@code from_date} are skipped.
     * A {@code FROM} inside the select list (sub-select, {@code EXTRACT(... FROM ...)}) is still
     * picked up first. If no {@code FROM} is found, the statement is counted as a subquery.
     *
     * @param page paging request carrying the extra conditions
     * @param sql  original SQL statement
     * @return the count SQL
     */
    private String getCountSql(Page<?> page, String sql) {
        int fromIndex = indexOfKeyword(sql, "from");
        StringBuilder countSql = new StringBuilder("select count(*) ");
        if (fromIndex >= 0) {
            countSql.append(sql.substring(fromIndex));
        } else {
            countSql.append("from (").append(sql).append(") page_count_src");
        }
        appendParamConditions(page, countSql);
        return countSql.toString();
    }

    /**
     * Finds the first case-insensitive, whole-word occurrence of {@code keyword} in {@code sql}.
     *
     * @return the start index, or -1 when the keyword does not occur as a whole word
     */
    private static int indexOfKeyword(String sql, String keyword) {
        int keywordLength = keyword.length();
        int lastStart = sql.length() - keywordLength;
        for (int start = 0; start <= lastStart; start++) {
            if (!sql.regionMatches(true, start, keyword, 0, keywordLength)) {
                continue;
            }
            int end = start + keywordLength;
            boolean startsWord = start == 0 || !isIdentifierChar(sql.charAt(start - 1));
            boolean endsWord = end == sql.length() || !isIdentifierChar(sql.charAt(end));
            if (startsWord && endsWord) {
                return start;
            }
        }
        return -1;
    }

    /** Tells whether {@code c} can be part of an unquoted SQL identifier. */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$' || c == '#';
    }

    /**
     * Small reflection helper for reading and writing private fields.
     */
    private static class ReflectUtil {

        private ReflectUtil() {
            // static helpers only
        }

        /**
         * Reads a field, searching the object's class and then its superclasses.
         *
         * @param obj       target object
         * @param fieldName name of the field
         * @return the field's value, or {@code null} if the field is missing or unreadable
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Field field = getField(obj, fieldName);
            if (field == null) {
                return null;
            }
            field.setAccessible(true);
            try {
                return field.get(obj);
            } catch (IllegalArgumentException e) {
                log.error("Cannot read field '" + fieldName + "'", e);
            } catch (IllegalAccessException e) {
                log.error("Cannot read field '" + fieldName + "'", e);
            }
            return null;
        }

        /**
         * Looks up a declared field on the object's class or the nearest superclass that has it.
         *
         * @param obj       target object
         * @param fieldName name of the field
         * @return the field, or {@code null} if no class in the hierarchy declares it
         */
        private static Field getField(Object obj, String fieldName) {
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                try {
                    return clazz.getDeclaredField(fieldName);
                } catch (NoSuchFieldException ignored) {
                    // Not declared here; a superclass may declare it. Null is returned if none does.
                }
            }
            return null;
        }

        /**
         * Writes a String value into a field, searching the object's class and its superclasses.
         * Does nothing when the field does not exist; failures are logged.
         *
         * @param obj        target object
         * @param fieldName  name of the field
         * @param fieldValue value to store
         */
        public static void setFieldValue(Object obj, String fieldName,
                String fieldValue) {
            Field field = getField(obj, fieldName);
            if (field == null) {
                return;
            }
            try {
                field.setAccessible(true);
                field.set(obj, fieldValue);
            } catch (IllegalArgumentException e) {
                log.error("Cannot write field '" + fieldName + "'", e);
            } catch (IllegalAccessException e) {
                log.error("Cannot write field '" + fieldName + "'", e);
            }
        }
    }
}