springboot动态多数据源切换,解决多数据源事务问题

50

@[toc]

1.业务场景

通过分库实现不同租户的业务数据隔离,在笔者的公司是一个集团数据库作为一个master库,集团拥有众多的分公司,每一个分公司都有一个slave库,从而做到集团的数据和各个分公司的数据库的数据隔离

2.主要思路

  1. 项目默认使用加载集团master库(库中包含子公司信息表,存储了每一个分公司数据库的连接信息)来启动项目
  2. 监听项目的启动,当项目启动时,读取master库中的子公司的信息表,将子公司的数据库连接读取到系统中
  3. 【重点】继承AbstractRoutingDataSource类,重写determineCurrentLookupKey方法的返回结果以达到在业务中动态切换数据库

    3.核心抽象类AbstractRoutingDataSource分析

    e6d2907d2b92d8787f917d0afbc60d73.png
    该类的getConnection是获取数据库连接的核心方法,determineTargetDataSource方法是connection的来源
    abac3ee669a0650a06f43c7ba9202b21.png
    resolvedDataSrouces是用来保存多个数据源的Map,determineCurrentLookupKey方法返回的就是多数据源Map的Key,所以修改determineCurrentLookupKey即可动态切换数据源

    4.开始编码

    多数据源获取规则配置
    public class MultiRouteDataSource extends AbstractRoutingDataSource {
    
    /**
     * 抽象类AbstractRoutingDataSource中的resolvedDataSources存储了最终解析后的数据源
     * 由于没有提供get的方法,所以此处自己维护一个Map保存数据源,如果后续业务需要获取数据源可以从此处获取
     */
    private static Map<Object, Object> targetDataSources = new HashMap<>();
    
    /**
     * 重写获取数据源的逻辑
     * DataSourceContext中维护一个ThreadLocal用于保存当前西安城正在使用的数据源的key
     *
     * @return 想要获取到的数据源的key
     */
    @Override
    protected Object determineCurrentLookupKey() {
        // 通过绑定线程的数据源上下文实现多数据源的动态切换
        return DataSourceContext.getDataSource();
    }
    
    /**
     * 动态增加数据源
     * 当项目启动时,用该方法添加读取到的子公司数据源
     *
     * @param name       数据源名称
     * @param dataSource 数据源属性
     */
    public synchronized void addDataSource(String name, HikariDataSource dataSource) {
        try {
            // 对该数据源进行解析验证,验证后的数据源才会保存到resolvedDataSources中
            Map<Object, Object> targetMap = MultiRouteDataSource.targetDataSources;
            targetMap.put(name, dataSource);
            this.afterPropertiesSet();
        } catch (Exception e) {
            logger.error(e.getMessage());
        }
    }
    

}

public class DataSourceContext {

/**
 * 存储的是数据源的key
 */
private static final ThreadLocal<String> HOLDER = new ThreadLocal<>();

/**
 * 获取当前的数据源key
 *
 * @return data source name
 */
public static String getDataSource() {
    return HOLDER.get();
}

/**
 * 设置当前数据源key
 *
 * @param dataSource the dataSource
 */
public static synchronized void setDataSource(String dataSource) {
    HOLDER.set(dataSource);
}

/**
 * 清除数据源
 */
public static void clearDataSource() {
    HOLDER.remove();
}

}

默认集团数据源和事务管理器配置

@Configuration
public class DataSourceConfig {

/**
 * 默认加载的数据源
 */
@Bean("master")
@Primary
@ConfigurationProperties(prefix = "spring.datasource.hikari")
public DataSource masterDataSource() {
    return new HikariDataSource();
}

/**
 * 注册动态数据源
 *
 * @return
 */
@Bean(name = "multiDataSource")
public MultiRouteDataSource multiDataSource() {
    MultiRouteDataSource dynamicRoutingDataSource = new MultiRouteDataSource();
    Map<Object, Object> dataSourceMap = new HashMap<>();
    dataSourceMap.put("master", masterDataSource());
    // 设置默认数据源
    dynamicRoutingDataSource.setDefaultTargetDataSource(masterDataSource());
    dynamicRoutingDataSource.setTargetDataSources(dataSourceMap);
    return dynamicRoutingDataSource;
}

/**
 * session管理器
 *
 * @return
 */
@Bean
public SqlSessionFactoryBean sqlSessionFactoryBean() {
    SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();
    // 必须将动态数据源添加到 sqlSessionFactoryBean
    sqlSessionFactoryBean.setDataSource(multiDataSource());
    return sqlSessionFactoryBean;
}

/**
 * 事务管理器
 *
 * @return the platform transaction manager
 */
@Bean
public PlatformTransactionManager transactionManager() {
    return new DataSourceTransactionManager(multiDataSource());
}

}

到目前位置,在service业务中,已经可以通过设置ThreadLocal中的数据源key来切换数据源了,测试用例

@Test
public void MultiDataSourceTest(){

DataSourceContext.setDataSource("master");
List<Foo> dataA = fooService.getAll();

DataSourceContext.setDataSource("slave");
List<Foo> dataB = fooService.getAll();

}

测试结果可以发现dataA和dataB中的数据是不同的,说明数据一个来自master库,一个来自slave库
# 5.踩坑  
此处踩坑指的是事务的坑,这里分两种
- A类 : 一个业务中只对一个数据库进行了增删改
- B类 : 一个业务中同时对多个数据库进行了增删改
## A类坑
A类解决方式比较简单,无需特殊处理,直接使用@Transactional注解,将事务管理交给spring来管理,

@Transactional(rollbackFor = Exception.class)
public void create(Foo foo) {

...

}

## B类坑
B类业务在笔者公司的业务情况如下
1. 当集团公司给子公司分配权限的时候
2. 首先在集团库(master)中的[子公司权限中间表]中新增中间表关联数据
3. 将分配的权限同步到子公司的数据库(slave)中
简略代码如下

public void assignPrivileges(){

DataSourceContext.setDataSource("master");
privilegeMapper.insert(xxx);

DataSourceContext.setDataSource("slave");
privilegeMapper.insert(xxx);

}

同时在master库和slave库中进行增删改操作,应该保证master库和slave库数据一致,需要事务支持,
但是如果使用@Transactional注解后,会发现assignPrivileges方法中DataSourceContext失效,无法切换数据源,原因是因为spring的事务管理使用了aop代理,在方法开始前,已经将当前数据源绑定在了线程中,所以无论怎样切换,使用的都是同一个数据源,spring事务管理代码截图如下
![7f6cc6b217880537e6ac39a7a99394d1.png](/upload/7f6cc6b217880537e6ac39a7a99394d1.png)
红框处spring的事务管理将数据源进行了线程绑定
## B类解决方案【踩坑 & 核心内容】
不让spring管理事务,我们自己管理事务
解决思路,每次切换数据源之后,获取到数据源的数据库连接connection,将connection的自动提交设置成false,在业务完成后,统一的对所有的connection集体提交事务,如果异常,则集体对connection进行回滚操作

> 这里我们需要切换数据源后获取到数据源的连接,故此时需要用到之前我们自己维护的map中保存的数据源,在MultiRouteDataSource中添加一个获取数据源的方法getDataSource()

public class MultiRouteDataSource extends AbstractRoutingDataSource {

public DataSource getDataSource(String key) {
    return (DataSource) targetDataSources.get(key);
}


}

上面的方案听起来是不是和spring的实现方式很像,没错,其实可以借鉴aop使用前置通知、后置通知、异常处理来做到,但是笔者公司的业务使用到的数据源的个数是不定的,所以没有使用自定义注解+aop实现来简化操作,而是使用了函数式变成的思想,做了一个函数接口来完成,读者可以通过自己的业务,使用注解+aop改造以下代码

核心代码如下

public class TransactionUtils {

@Autowired
private MultiRouteDataSource multiRouteDataSource;

/**
 * @param codes           用到的数据源的key
 * @param transactionType 事务隔离级别
 * @param executor        执行器
 */
public void execute(Collection<String> codes, TransactionTypeEnum transactionType, Executor executor) throws Exception {
    for (String code : codes) {
        DataSource dataSource = multiRouteDataSource.getDataSource(code);
        if (dataSource == null) {
            continue;
        }
        Connection connection = dataSource.getConnection();
        if (connection != null) {
            // 设置事务隔离级别
            connection.setTransactionIsolation(transactionType == null ? Connection.TRANSACTION_NONE : transactionType.getValue());
            if (connection.getAutoCommit()) {
                connection.setAutoCommit(false);
            }
        }

        // 将连接绑定到当前线程
        multiRouteDataSource.bindConnection(code, connection);
    }

    try {
        executor.invoke();
        multiRouteDataSource.doCommit();
    } catch (Exception e) {
        multiRouteDataSource.rollback();
        throw e;
    } finally {

    }
}

public void execute(Collection<String> codes, Executor executor) throws Exception {
    // 默认的事务隔离就别
    execute(codes, TransactionTypeEnum.TRANSACTION_READ_UNCOMMITTED, executor);
}

public interface Executor {
    void invoke();
}

}

> 这里使用到一个新的方法`multiRouteDataSource.bindConnection(String, DataSource)`
> 目的是将设置了非自动提交的数据库连接保存到map中,然后在`executor.invoke()`方法即具体的业务实现时,使之方法中用到的所有连接都从这个map中获取,则`executor.invoke()`中与数据库相关操作都不会自动提交,`executor.invoke()`方法结束后我们自己提交或回滚

我们先对以上使用到的MultiRouteDataSource中的方法进行扩展

public class MultiRouteDataSource extends AbstractRoutingDataSource {

...

/**
 * 保存当前线程使用了事务的数据库连接(connection)
 * 当我们自己管理事务的时候即可从此处获取到当前线程使用了哪些连接从而让这些被使用的连接commit/rollback/close
 */
private ThreadLocal<Map<String, Connection>> connectionThreadLocal = new ThreadLocal<>();

/**
 * 开启事物的时候,把连接放入 线程中,后续crud 都会拿对应的连接操作
 *
 * @param key        子公司code
 * @param connection 连接
 */
public void bindConnection(String key, Connection connection) {
    Map<String, Connection> connectionMap = connectionThreadLocal.get();
    if (connectionMap == null) {
        connectionThreadLocal.set(connectionMap);
    }
    connectionMap.put(key, connection);
}

/**
 * 提交事物
 *
 * @throws SQLException SQLException
 */
public void doCommit() throws SQLException {
    Map<String, Connection> stringConnectionMap = connectionThreadLocal.get();
    if (stringConnectionMap == null) {
        return;
    }
    for (String dataSourceName : stringConnectionMap.keySet()) {
        Connection connection = stringConnectionMap.get(dataSourceName);
        connection.commit();
        connection.close();
    }
    removeConnectionThreadLocal();
}

/**
 * 回滚事物
 *
 * @throws SQLException SQLException
 */
public void rollback() throws SQLException {
    Map<String, Connection> stringConnectionMap = connectionThreadLocal.get();
    if (stringConnectionMap == null) {
        return;
    }
    for (String dataSourceName : stringConnectionMap.keySet()) {
        Connection connection = stringConnectionMap.get(dataSourceName);
        connection.rollback();
        connection.close();
    }
    removeConnectionThreadLocal();
}

protected void removeConnectionThreadLocal() {
    connectionThreadLocal.remove();
}

...

}

### **做完这一步,对以上的代码进行测试,会发现当一个service中使用了多个数据源对数据库进行操作后,并不能正常的回滚事务,原因如下**
当事务不再由spring管理后,mybatis会使用自己的事务管理机制,即在操作完数据库后自动提交和关闭,所以解决方法就是重写Connection对象的提交和关闭方法,使mybatis的自动提交和关闭不生效

/**

  • 自己实现的Connection

  • 目的是当事务不由spring管理的时候,myabtis执行完mapper的方法会自动commit并close

  • 所以重新commit方法和close方法,让mybatis的事务控制失效,我们自己来控制事务
    */
    public class ConnectWarp implements Connection {

    private Connection connection;

    public ConnectWarp(Connection connection) {

    this.connection = connection;
    

    }

    /**

    • 当mybatis自身执行完成后调用commit方法后没有实质性的commit
      *
    • @throws SQLException SQLException
      */
      @Override
      public void commit() throws SQLException {
      // connection.commit();
      }

    /**

    • 该方法为我们自己想要提交事务的时候调用
      *
    • @throws SQLException SQLException
      */
      public void realCommit() throws SQLException {
      connection.commit();
      }

    /**

    • 当mybatis自身执行完成后调用close方法后没有实质性的close
      *
    • @throws SQLException SQLException
      */
      @Override
      public void close() throws SQLException {
      //connection.close();
      }

    /**

    • 如果包装类 要用这个方法关闭
      *
    • @throws SQLException SQLException
      */
      public void realClose() throws SQLException {
      connection.close();
      }

    /**

    • 创建一个 Statement对象,用于将SQL语句发送到数据库。
      */
      @Override
      public Statement createStatement() throws SQLException {
      return connection.createStatement();
      }

    /**

    • 创建一个 PreparedStatement对象,用于将参数化的SQL语句发送到数据库。
      *
    • @param sql 预处理sql foo:select * from foo where id = ?
      */
      @Override
      public PreparedStatement prepareStatement(String sql) throws SQLException {
      return connection.prepareStatement(sql);
      }

    /**

    • 创建一个调用数据库存储过程的 CallableStatement对象。
      *
    • @param sql
      */
      @Override
      public CallableStatement prepareCall(String sql) throws SQLException {
      return connection.prepareCall(sql);
      }

    /**

    • 将给定的SQL语句转换为系统的本机SQL语法
      *
    • @param sql
      */
      @Override
      public String nativeSQL(String sql) throws SQLException {
      return connection.nativeSQL(sql);
      }

    /**

    • 将此连接的自动提交模式设置为给定状态。
      *
    • @param autoCommit 是否自动提交
      */
      @Override
      public void setAutoCommit(boolean autoCommit) throws SQLException {
      connection.setAutoCommit(autoCommit);
      }

    /**

    • 对象的当前自动提交模式
      */
      @Override
      public boolean getAutoCommit() throws SQLException {
      return connection.getAutoCommit();
      }

    /**

    • 撤消在当前事务中所做的所有更改,并释放此 Connection对象当前持有的任何数据库锁。
      */
      @Override
      public void rollback() throws SQLException {
      connection.rollback();
      }

    /**

    • 此Connection对象是否已关闭。
      */
      @Override
      public boolean isClosed() throws SQLException {
      return connection.isClosed();
      }

    /**

    • 获取对象的数据源信息
      */
      @Override
      public DatabaseMetaData getMetaData() throws SQLException {
      return connection.getMetaData();
      }

    /**

    • 设置连接是否只读
      */
      @Override
      public void setReadOnly(boolean readOnly) throws SQLException {
      connection.setReadOnly(readOnly);
      }

    /**

    • 检索连接的只读状态
      */
      @Override
      public boolean isReadOnly() throws SQLException {
      return connection.isReadOnly();
      }

    @Override
    public void setCatalog(String catalog) throws SQLException {

    connection.setCatalog(catalog);
    

    }

    @Override
    public String getCatalog() throws SQLException {

    return connection.getCatalog();
    

    }

    /**

    • 设置事物的隔离级别
      */
      @Override
      public void setTransactionIsolation(int level) throws SQLException {
      connection.setTransactionIsolation(level);
      }

    /**

    • 获取当前连接的事物隔离级别
      */
      @Override
      public int getTransactionIsolation() throws SQLException {
      return connection.getTransactionIsolation();
      }

    /**

    • 获取连接报告中的第一个警告
      */
      @Override
      public SQLWarning getWarnings() throws SQLException {
      return connection.getWarnings();
      }

    /**

    • 清空所有连接报告
      */
      @Override
      public void clearWarnings() throws SQLException {
      connection.clearWarnings();
      }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {

    return connection.createStatement(resultSetType, resultSetConcurrency);
    

    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {

    return connection.prepareStatement(sql, resultSetType, resultSetConcurrency);
    

    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {

    return connection.prepareCall(sql, resultSetType, resultSetConcurrency);
    

    }

    @Override
    public Map> getTypeMap() throws SQLException {

    return connection.getTypeMap();
    

    }

    @Override
    public void setTypeMap(Map> map) throws SQLException {

    connection.setTypeMap(map);
    

    }

    @Override
    public void setHoldability(int holdability) throws SQLException {

    connection.setHoldability(holdability);
    

    }

    @Override
    public int getHoldability() throws SQLException {

    return connection.getHoldability();
    

    }

    @Override
    public Savepoint setSavepoint() throws SQLException {

    return connection.setSavepoint();
    

    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {

    return connection.setSavepoint(name);
    

    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {

    connection.rollback(savepoint);
    

    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {

    connection.releaseSavepoint(savepoint);
    

    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {

    return connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    

    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {

    return connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    

    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {

    return connection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    

    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {

    return connection.prepareStatement(sql, autoGeneratedKeys);
    

    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {

    return connection.prepareStatement(sql, columnIndexes);
    

    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {

    return connection.prepareStatement(sql, columnNames);
    

    }

    @Override
    public Clob createClob() throws SQLException {

    return connection.createClob();
    

    }

    @Override
    public Blob createBlob() throws SQLException {

    return connection.createBlob();
    

    }

    @Override
    public NClob createNClob() throws SQLException {

    return connection.createNClob();
    

    }

    @Override
    public SQLXML createSQLXML() throws SQLException {

    return connection.createSQLXML();
    

    }

    @Override
    public boolean isValid(int timeout) throws SQLException {

    return connection.isValid(timeout);
    

    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {

    connection.setClientInfo(name, value);
    

    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {

    connection.setClientInfo(properties);
    

    }

    @Override
    public String getClientInfo(String name) throws SQLException {

    return connection.getClientInfo(name);
    

    }

    @Override
    public Properties getClientInfo() throws SQLException {

    return connection.getClientInfo();
    

    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {

    return connection.createArrayOf(typeName, elements);
    

    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {

    return connection.createStruct(typeName, attributes);
    

    }

    @Override
    public void setSchema(String schema) throws SQLException {

    connection.setSchema(schema);
    

    }

    @Override
    public String getSchema() throws SQLException {

    return connection.getSchema();
    

    }

    @Override
    public void abort(Executor executor) throws SQLException {

    connection.abort(executor);
    

    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {

    connection.setNetworkTimeout(executor, milliseconds);
    

    }

    @Override
    public int getNetworkTimeout() throws SQLException {

    return connection.getNetworkTimeout();
    

    }

    @Override
    public T unwrap(Class iface) throws SQLException {

    return connection.unwrap(iface);
    

    }

    @Override
    public boolean isWrapperFor(Class iface) throws SQLException {

    return connection.isWrapperFor(iface);
    

    }
    }

bindConnection方法应该保存包装后的Connection对象,并且只有在多个数据源需要回滚的业务中才获取包装后的Connection,不去影响正常情况下mybatis的事务管理,所有还需要重写AbstractRoutingDataSource中的getConnection方法

public class MultiRouteDataSource extends AbstractRoutingDataSource {

    ...

     /**
     * mybatis在使用mapper接口执行sql的时候会从该方法获取connection执行sql
     * 如果事务是spring或者mybatis在管理,那么直接返回原生的connection
     * 如果是我们自己控制事务,则返回我们自己实现的ConnetWarp
     *
     * @return Connection
     * @throws SQLException SQLException
     */
    @Override
    public Connection getConnection() throws SQLException {
        Map<String, ConnectWarp> stringConnectionMap = connectionThreadLocal.get();
        if (stringConnectionMap == null) {
            // 没开事物 直接返回
            return determineTargetDataSource().getConnection();
        } else {
            // 开了事物 从当前线程中拿 而且拿到的是 包装过的connect 只有手动去提交和关闭连接
            String currentName = (String) determineCurrentLookupKey();
            return stringConnectionMap.get(currentName);
        }
    }   

    /**
     * 开启事物的时候,把连接放入 线程中,后续crud 都会拿对应的连接操作
     *
     * @param key        子公司code
     * @param connection 连接
     */
    public void bindConnection(String key, Connection connection) {
        Map<String, ConnectWarp> connectionMap = connectionThreadLocal.get();
        if (connectionMap == null) {
            connectionMap = new HashMap<>();
            connectionThreadLocal.set(connectionMap);
        }

        ConnectWarp connectWarp = new ConnectWarp(connection);

        connectionMap.put(key, connectWarp);
    }

    /**
     * 提交事物
     *
     * @throws SQLException SQLException
     */
    public void doCommit() throws SQLException {
        Map<String, ConnectWarp> stringConnectionMap = connectionThreadLocal.get();
        if (stringConnectionMap == null) {
            return;
        }
        for (String dataSourceName : stringConnectionMap.keySet()) {
            ConnectWarp connection = stringConnectionMap.get(dataSourceName);
            connection.realCommit();
            connection.realClose();
        }
        removeConnectionThreadLocal();
    }

    /**
     * 回滚事物
     *
     * @throws SQLException SQLException
     */
    public void rollback() throws SQLException {
        Map<String, ConnectWarp> stringConnectionMap = connectionThreadLocal.get();
        if (stringConnectionMap == null) {
            return;
        }
        for (String dataSourceName : stringConnectionMap.keySet()) {
            ConnectWarp connection = stringConnectionMap.get(dataSourceName);
            connection.rollback();
            connection.realClose();
        }
        removeConnectionThreadLocal();
    }

    ...

}

到此所有问题都已解决
笔者遇到mybatis自己管理事务的坑时,借鉴了一篇博客中的观点,博客如下:https://blog.csdn.net/u010928589/article/details/91348761