这个是rpc远程调用的简单demo:Consumer通过rpc远程调用Provider的服务方法sayHelloWorld(String msg),然后Provider返回""Hello World"给Consumer。
这里采用netty来实现远程通信实现rpc调用,消费者通过代理来进行远程调用远程服务。本文涉及的知识点有代理模式,jdk动态代理和netty通信。这个简单demo将服务提供者的服务注册缓存在jvm本地,后续将会考虑将服务提供者的服务注册到zookeeper注册中心。
这个简单demo将从以下四方面去进行实现,第一是公共基础层,这一层是Consumer和Provider将会共用的api和netty远程通信之间要交换的信息;第二是Provider本地注册服务的实现;第三是Provider的实现,第四是Consumer的实现。废话不多说,下面直接上代码:
package com.jinyue.common.message; import java.io.Serializable; /** * netty远程通信过程中传递的消息 */ public class RpcMessage implements Serializable { private String className; private String methodName; private Class<?>[] parameterType; private Object[] parameterValues; public RpcMessage(String className, String methodName, Class<?>[] parameterType, Object[] parameterValues) { this.className = className; this.methodName = methodName; this.parameterType = parameterType; this.parameterValues = parameterValues; } public void setClassName(String className) { this.className = className; } public void setMethodName(String methodName) { this.methodName = methodName; } public void setParameterType(Class<?>[] parameterType) { this.parameterType = parameterType; } public void setParameterValues(String parameterValue) { this.parameterValues = parameterValues; } public String getClassName() { return className; } public String getMethodName() { return methodName; } public Class<?>[] getParameterType() { return parameterType; } public Object[] getParameterValues() { return parameterValues; } } 复制代码
package com.jinyue.common.api; public interface IHelloWorld { String sayHelloWorld(String name, String content); } 复制代码
package com.jinyue.registry; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.codec.serialization.ClassResolvers; import io.netty.handler.codec.serialization.ObjectDecoder; import io.netty.handler.codec.serialization.ObjectEncoder; import org.apache.log4j.Logger; /** * 这个作为provider的提供者启动类,实质就是启动netty服务时, * 添加ProviderRegistryHandler到netty的handler处理函数中。 */ public class LocalRegistryMain { private static final Logger logger = Logger.getLogger(LocalRegistryMain.class); private static final int SERVER_PORT = 8888; public static void main(String[] args) { // 创建主从EventLoopGroup EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup(); try { ServerBootstrap serverBootstrap = new ServerBootstrap(); // 将主从主从EventLoopGroup绑定到server上 serverBootstrap.group(bossGroup, workerGroup) .channel(NioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, 128) .childOption(ChannelOption.SO_KEEPALIVE, true) .childHandler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); // 这里添加解码器和编码器,防止拆包和粘包问题 pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)); pipeline.addLast(new LengthFieldPrepender(4)); // 这里采用jdk的序列化机制 pipeline.addLast("jdkencoder", new ObjectEncoder()); pipeline.addLast("jdkdecoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null))); // 添加自己的业务逻辑,将服务注册的handle添加到pipeline pipeline.addLast(new ProviderNettyHandler()); } }); logger.info("server start,the port is " + SERVER_PORT); // 这里同步等待future的返回,若返回失败,那么抛出异常 ChannelFuture future = serverBootstrap.bind(SERVER_PORT).sync(); // 关闭future future.channel().closeFuture().sync(); } catch (InterruptedException e) { e.printStackTrace(); } finally { // 最后记得主从group要优雅停机。 bossGroup.shutdownGracefully(); workerGroup.shutdownGracefully(); } } } 复制代码
package com.jinyue.registry; import com.jinyue.common.message.RpcMessage; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import java.lang.reflect.Method; /** * 有consumer调用时,此时ProviderNettyHandler再从ProviderRestry类的缓存实例根据传过来的接口名拿到实现类实例, * 然后再拿到实现类实例的方法,再对该方法进行反射调用,最后将调用后的结果返回给consumer即可。 */ public class ProviderNettyHandler extends ChannelInboundHandlerAdapter { /** * 当netty服务端接收到有consumer的请求时,此时将会进入到这个channelRead方法 * 此时就可以把consumer调用的参数提取出来,然后再从ProviderRestry类的缓存注册中心instanceCacheMap里 * 提取出反射实例,然后进行方法调用,再返回结果给consumer即可 * @param ctx * @param msg * @throws Exception */ @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { // 提取consumer传递过来的参数 RpcMessage rpcMessage = (RpcMessage) msg; String interfaceName = rpcMessage.getClassName(); String methodName = rpcMessage.getMethodName(); Class<?>[] parameterType = rpcMessage.getParameterType(); Object[] parameterValues = rpcMessage.getParameterValues(); // 将注册缓存instanceCacheMap的provider实例提取出来,然后进行反射调用 Object instance = ProviderLocalRegistry.getInstanceCacheMap().get(interfaceName); Method method = instance.getClass().getMethod(methodName, parameterType); Object res = method.invoke(instance, parameterValues); // 最后将结果刷到netty的输出流中返回给consumer ctx.writeAndFlush(res); ctx.close(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { cause.printStackTrace(); ctx.close(); } } 复制代码
package com.jinyue.registry; import org.apache.log4j.Logger; import java.io.File; import java.net.URL; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; /** * 该类主要时充当“注册中心的作用” * 将provider的服务实现类注册到本地缓存里面,采用ConcurrentHashMap【key为接口名,value为服务实例】 */ public class ProviderLocalRegistry { private static final Logger logger = Logger.getLogger(ProviderNettyHandler.class); // 服务提供者所在的包 private static final String PROVIDER_PACKAGE_NAME = "com.jinyue.provider"; // 用来装服务提供者的实例 private static Map<String, Object> instanceCacheMap = new ConcurrentHashMap<>(); // 用来存放实现类的类名 private static List<String> providerClassList = new ArrayList<>(); static { // 扫描provider包下面的实现类,并放进缓存instanceMap里面 loadProviderInstance(PROVIDER_PACKAGE_NAME); } /** * 扫描provider包下面的实现类,并放进缓存instanceMap里面 * @param packageName */ private static void loadProviderInstance(String packageName) { findProviderClass(packageName); putProviderInstance(); } /** * 找到provider包下所有的实现类名,并放进providerClassList里 */ private static void findProviderClass(final String packageName) { // 静态方法内不能用this关键字 // this.getClass().getClassLoader().getResource(PROVIDER_PACKAGE_NAME.replace("//.", "/")); // 所以得用匿名内部类来解决 // 这里由classLoader的getResource方法获得包名并封装成URL形式 URL url = new Object() { public URL getPath() { String packageDir = packageName.replace(".", "/"); URL o = this.getClass().getClassLoader().getResource(packageDir); return o; } }.getPath(); // 将该包名转换为File格式,用于以下判断是文件夹还是文件,若是文件夹则递归调用本方法, // 若不是文件夹则直接将该provider的实现类的名字放到providerClassList中 File dir = new File(url.getFile()); File[] fileArr = dir.listFiles(); for (File file : fileArr) { if (file.isDirectory()) { findProviderClass(packageName + "." + file.getName()); } else { providerClassList.add(packageName + "." + file.getName().replace(".class", "")); } } } /** * 遍历providerClassList集合的实现类,并依次将实现类的接口作为key,实现类的实例作为值放入instanceCacheMap集合中,其实这里也是模拟服务注册的过程 * 注意这里没有处理一个接口有多个实现类的情况 */ private static void putProviderInstance() { for (String providerClassName : providerClassList) { // 已经得到providerClassName,因此可以通过反射来生成实例 try { Class<?> providerClass = Class.forName(providerClassName); // 这里得到实现类的接口的全限定名作为key,因为consumer调用时是传接口的全限定名过来从缓存中获取实例再进行反射调用 String providerClassInterfaceName = providerClass.getInterfaces()[0].getName(); // 得到Provicder实现类的实例 Object instance = providerClass.newInstance(); instanceCacheMap.put(providerClassInterfaceName, instance); logger.info("注册了" + providerClassInterfaceName + "的服务"); } catch (Exception e) { e.printStackTrace(); } } } public static Map<String, Object> getInstanceCacheMap() { return instanceCacheMap; } } 复制代码
package com.jinyue.provider; import com.jinyue.common.api.IHelloWorld; /** * 服务提供者 */ public class HelloWorldImpl implements IHelloWorld { public String sayHelloWorld(String name, String content) { return name + " say:" + content; } } 复制代码
package com.jinyue.consumer; import com.jinyue.common.api.IHelloWorld; import com.jinyue.consumer.proxy.RpcProxyFactory; /** * z这个是consumer客户端测试类 */ public class ConsumerTest { public static void main(String[] args) { IHelloWorld helloWorld = (IHelloWorld)new RpcProxyFactory(IHelloWorld.class).getProxyInstance(); System.out.println(helloWorld.sayHelloWorld("jinyue", "hello world!")); } } 复制代码
package com.jinyue.consumer.proxy; import com.jinyue.consumer.request.ConsumerNettyRequest; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; /** * 动态代理工厂类,生成调用目标接口的代理类,这个代理类实质就是在InvocationHandler的invoke方法里面调用 * netty的发送信息给服务端的相关请求方法而已,把调用目标接口类的相关信息(比如目标接口名,被调用的目标方法, * 被调用目标方法的参数类型,参数值)发送给netty服务端,netty服务端接收到请求的这些信息后,然后再从缓存map * (模拟注册中心)拿到provider的实现类,然后再利用反射进行目标方法的调用。 */ public class RpcProxyFactory { private Class<?> target; public RpcProxyFactory(Class<?> target) { this.target = target; } public Object getProxyInstance() { return Proxy.newProxyInstance(target.getClassLoader(), new Class[]{target}, new InvocationHandler() { public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { return new ConsumerNettyRequest().sendRequest(target.getName(), method.getName(), method.getParameterTypes(), args); } }); } } 复制代码
package com.jinyue.consumer.request; import com.jinyue.common.message.RpcMessage; import com.jinyue.consumer.handler.ConsumerNettyHandler; import io.netty.bootstrap.Bootstrap; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.codec.serialization.ClassResolvers; import io.netty.handler.codec.serialization.ObjectDecoder; import io.netty.handler.codec.serialization.ObjectEncoder; /** * 这个类主要承担consumer对netty服务端发请请求的相关逻辑 */ public class ConsumerNettyRequest { public Object sendRequest(String interfaceName, String methodName, Class<?>[] parameterType, Object[] parameterValues) { EventLoopGroup eventLoopGroup = new NioEventLoopGroup(); ConsumerNettyHandler consumerNettyHandler = new ConsumerNettyHandler(); try { Bootstrap bootstrap = new Bootstrap(); bootstrap.group(eventLoopGroup) .channel(NioSocketChannel.class) .option(ChannelOption.TCP_NODELAY, true) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); // 这里添加解码器和编码器,防止拆包和粘包问题 pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)); pipeline.addLast(new LengthFieldPrepender(4)); // 这里采用jdk的序列化机制 pipeline.addLast("jdkencoder", new ObjectEncoder()); pipeline.addLast("jdkdecoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null))); // 添加自己的业务逻辑,将服务注册的handle添加到pipeline pipeline.addLast(consumerNettyHandler); } }); ChannelFuture future = bootstrap.connect("127.0.0.1", 8888).sync(); future.channel().writeAndFlush(new RpcMessage(interfaceName, methodName, parameterType, parameterValues)).sync(); future.channel().closeFuture().sync(); } catch (Exception e) { e.printStackTrace(); } finally { eventLoopGroup.shutdownGracefully(); } return consumerNettyHandler.getRes(); } } 复制代码
package com.jinyue.consumer.handler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; /** * 该类主要是客户端请求netty服务端后且当返回结果时,会回调channelRead方法接收rpc调用返回结果 */ public class ConsumerNettyHandler extends ChannelInboundHandlerAdapter { private Object res; @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { this.res = msg; } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { super.exceptionCaught(ctx, cause); } public Object getRes() { return res; } }复制代码
最后执行以下代码即运行前面的ConsumerTest类进行consumer通过netty rpc调用provider的sayHelloWorld方法进行测试:
public class ConsumerTest { public static void main(String[] args) { IHelloWorld helloWorld = (IHelloWorld)new RpcProxyFactory(IHelloWorld.class).getProxyInstance(); System.out.println(helloWorld.sayHelloWorld("jinyue", "hello world!")); } }复制代码
最终的测试结果:
项目地址:
https://github.com/jinyue233/java-demo/tree/master/netty-rpc-demo