老司机们都知道,Spring 提供了一个 AbstractRoutingDataSource,可以用来实现数据源路由。但就目前能找到的资料来看,关于 Spring Boot 多数据源配置的文章,大都与翟永超的《Spring boot 多数据源的配置与使用》一文介绍的方式相同。
这里我们介绍如何使用 AbstractRoutingDataSource 实现 Spring Boot 多数据源。个人认为,相较于前面提到的实现方式,基于 AbstractRoutingDataSource 的方式更优雅。
我们知道,Spring Boot 引入了自动化配置:配置数据源时,只需要引入相关的 jar 包,并在 application.properties(或 application.yml)中配置相关内容,Spring Boot 就能创建出我们需要的数据源。这里我们希望继续沿用这种方式来实现多数据源。
如何在 Spring Boot 自动化配置的过程中植入我们需要的 AbstractRoutingDataSource?这就不得不提
BeanPostProcessor。它是 Spring 提供的一个工厂钩子(factory hook),通过它我们可以在 Bean 初始化前后对 Bean 实例进行重新定义、属性设置、包装等操作。既然在 BeanPostProcessor 中能替换 Bean 实例,我们就可以利用这一特性,在其中实例化一个 AbstractRoutingDataSource 的实现类对象,替换掉 Spring Boot 自动配置的数据源。
下面直接上代码。注意,这里使用的 Spring Boot 版本为 1.5.5.RELEASE;代码中用到的 RelaxedPropertyResolver 在 Spring Boot 2.x 中已被移除,升级时需要改用 Binder 等替代 API。
spring.datasource.url=jdbc:mysql://localhost:3306/test1
spring.datasource.username=root
spring.datasource.password=123
spring.datasource.names=ds1
spring.datasource.ds1.url=jdbc:mysql://localhost:3306/test2
spring.datasource.ds1.username=root
spring.datasource.ds1.password=123
/**
 * Holds, per thread, the name of the data source that the routing data source should use.
 *
 * <p>A {@code null} name means "no explicit choice".
 */
public class DataSourceContext {

    /** Per-thread slot for the selected data source name. */
    private static final ThreadLocal<String> CURRENT_NAME = new ThreadLocal<String>();

    /**
     * Returns the data source name bound to the calling thread.
     *
     * @return the bound name, or {@code null} if none is bound
     */
    public static String getDatasourceName() {
        return CURRENT_NAME.get();
    }

    /**
     * Binds a data source name to the calling thread.
     *
     * @param datasourceName the name to bind
     */
    public static void setDatasourceName(String datasourceName) {
        CURRENT_NAME.set(datasourceName);
    }

    /** Removes any data source name bound to the calling thread. */
    public static void clean() {
        CURRENT_NAME.remove();
    }
}
public class RoutingDataSource extends AbstractRoutingDataSource{ private static final Logger logger = LoggerFactory.getLogger(RoutingDataSource.class); private static final String DATASOURCE_PROPERTY_PREFIX = "spring.datasource."; private Environment environment; private Map<Object, Object> targetDataSources; @Override protected Object determineCurrentLookupKey() { String dataSourceName = DataSourceContext.getDatasourceName(); logger.info("datasource [{}].", StringUtils.hasText(dataSourceName)? dataSourceName : "master"); return dataSourceName; } @Override public void setTargetDataSources(Map<Object, Object> targetDataSources) { this.targetDataSources = new HashMap<>(targetDataSources); } @Override public void afterPropertiesSet() { buildTargetDataSources(); super.setTargetDataSources(targetDataSources); super.afterPropertiesSet(); } public void setEnvironment(Environment environment) { this.environment = environment; } private void buildTargetDataSources(){ RelaxedPropertyResolver propertyResolver = new RelaxedPropertyResolver(this.environment, DATASOURCE_PROPERTY_PREFIX); String targetDatasourceNames = propertyResolver.getProperty("names"); logger.info("target datasource names: {}", targetDatasourceNames); if(!StringUtils.hasText(targetDatasourceNames)){ return; } for (String name : targetDatasourceNames.split(",")){ Map<String, Object> subProperties = propertyResolver.getSubProperties(name + '.'); logger.info("sub properties: {}", subProperties); targetDataSources.put(name, buildDatasource(subProperties)); } } private DataSource buildDatasource(Map<String, Object> properties){ if(properties.containsKey("jndi-name")){ return buildJndiDatasource(properties.get("jndi-name").toString()); }else{ return buildJdbcDatasource(properties); } } private DataSource buildJdbcDatasource(Map<String, Object> properties){ DataSourceBuilder factory = DataSourceBuilder.create() .url(properties.get("url").toString()) .username(properties.get("username").toString()) 
.password(properties.get("password").toString()); return factory.build(); } private DataSource buildJndiDatasource(String datasourceName){ JndiDataSourceLookup jndiDataSourceLookup = new JndiDataSourceLookup(); return jndiDataSourceLookup.getDataSource(datasourceName); } }
@Component public class DatasourceProxyBeanProcessor implements BeanPostProcessor, EnvironmentAware { private Environment environment; @Override public void setEnvironment(Environment environment) { this.environment = environment; } @Override public Object postProcessBeforeInitialization(Object bean, String name) throws BeansException { return bean; } @Override public Object postProcessAfterInitialization(Object bean, String name) throws BeansException { if(bean instanceof DataSource){ DataSource dataSourceBean = (DataSource) bean; AbstractRoutingDataSource routingDataSource = new RoutingDataSource(); routingDataSource.setDefaultTargetDataSource(dataSourceBean); ((RoutingDataSource)routingDataSource).setEnvironment(this.environment); Map<Object, Object> targetDataSources = new HashMap<>(); targetDataSources.put("master", dataSourceBean); routingDataSource.setTargetDataSources(targetDataSources); routingDataSource.afterPropertiesSet(); return routingDataSource; } return bean; } }
通过上面三个类,我们已经可以在代码中使用 DataSourceContext.setDatasourceName("ds1") 和 DataSourceContext.clean() 来切换数据源。需要注意的是,clean() 应当放在 finally 块中调用,否则一旦发生异常,数据源名称会残留在当前线程中,影响该线程后续的数据库访问。当然,这种写法使用起来还比较繁琐,我们可以通过切面让调用更优雅一些。
/**
 * Marks a method that should run against a named data source.
 *
 * <p>The value is a data source name, for example one listed in {@code spring.datasource.names}.
 */
@Target({ElementType.METHOD})
@Retention(value = RetentionPolicy.RUNTIME)
public @interface TargetDataSource {
    /** Name of the data source to route to while the annotated method runs. */
    String value();
}
/**
 * Binds the data source named by {@link TargetDataSource} to the current thread while the
 * annotated method runs.
 */
@Component
@Aspect
public class DataSourceAspect {

    /**
     * Switches to {@code targetDataSource.value()} for the duration of the call.
     *
     * <p>Afterwards it restores whatever name was bound before, whether the call returns normally
     * or throws.
     *
     * <p>Restoring, rather than always clearing, keeps the outer choice intact when an annotated
     * method calls another annotated method.
     *
     * <p>NOTE(review): if the same method is also transactional, the connection may be acquired
     * before this advice runs. Confirm the advice ordering (e.g. {@code @Order}).
     */
    @Around("@annotation(targetDataSource)")
    public Object changeDatasource(ProceedingJoinPoint pjp, TargetDataSource targetDataSource) throws Throwable {
        String previous = DataSourceContext.getDatasourceName();
        DataSourceContext.setDatasourceName(targetDataSource.value());
        try {
            return pjp.proceed();
        } finally {
            if (previous == null) {
                DataSourceContext.clean();
            } else {
                DataSourceContext.setDatasourceName(previous);
            }
        }
    }
}
下面通过几个测试用例,验证一下是否达到了预期效果。
/**
 * Integration tests for data source routing.
 *
 * <p>They cover both manual switching via {@link DataSourceContext} and annotation-driven
 * switching through {@code UserService}.
 */
@RunWith(SpringRunner.class)
@SpringBootTest
public class ApplicationTests {
    private static final Logger logger = LoggerFactory.getLogger(ApplicationTests.class);

    @Autowired
    private UserService userService;

    @Before
    public void setup() {
        // Reset the thread-bound name first. Otherwise a name leaked by a previous test would make
        // clean() below run against the wrong data source.
        DataSourceContext.clean();
        userService.clean();
        userService.clean2();
    }

    @Test
    public void contextLoads() {
    }

    @Test
    public void userServiceTest() {
        userService.save(new User(1L, "aaa", 20));
        logger.info("{}", userService.findAll());
    }

    @Test
    public void changeDatasourceTest() {
        userService.save(new User(1L, "aaa", 20));

        DataSourceContext.setDatasourceName("ds1");
        try {
            userService.save(new User(2L, "bbb", 26));
        } finally {
            DataSourceContext.clean();
        }
        Assert.assertEquals(1, userService.findAll().size());

        DataSourceContext.setDatasourceName("ds1");
        try {
            Assert.assertEquals(1, userService.findAll().size());
        } finally {
            DataSourceContext.clean();
        }
    }

    @Test
    public void datasourceAspectTest() {
        userService.save(new User(1L, "aaa", 20));
        userService.save2(new User(2L, "bbb", 26));
        Assert.assertEquals(1, userService.findAll().size());
        Assert.assertEquals(1, userService.findAll2().size());
    }
}
整个代码还很粗糙,离生产可用还有很大距离,需要读者自行完善。完整工程代码见:https://github.com/loafer/spring-boot-tutorials/tree/master/springboot-multidatasource