手写SpringMVC框架
细嗅蔷薇 心有猛虎
背景: Spring 想必大家都听说过,可能现在更多流行的是Spring Boot 和Spring Cloud 框架;但是SpringMVC 作为一款 实现了 MVC 设计模式的 web (表现层) 层框架,其高开发效率和高性能也是现在很多公司仍在采用的框架;除此之外,Spring 源码大师级的代码规范和设计思想都十分值得学习;退一步说,Spring Boot 框架底层也有很多Spring 的东西,而且面试的时候还会经常被问到SpringMVC 原理,一般人可能也就是只能把SpringMVC 的运行原理背出来罢了,至于问到有没有了解其底层实现(代码层面),那很可能就歇菜了,但您要是可以手写SpringMVC 框架就肯定可以令面试官刮目相看,所以手写SpringMVC 值得一试。
在设计自己的SpringMVC 框架之前,需要了解下其运行流程。
一、SpringMVC 运行流程
图1. SpringMVC 运行流程
1、用户向服务器发送请求,请求被 Spring 前端控制器 DispatcherServlet 捕获;
2、DispatcherServlet 收到请求后调用HandlerMapping 处理器映射器;
3、处理器映射器对 请求 URL 进行解析,得到请求资源标识符( URI );然后根据该 URI, 调用 HandlerMapping 获得该 Handler 配置的所有相关的对象(包括 Handler 对象以及 Handler 对象对应的拦截 器),再以 HandlerExecutionChain 对象的形式返回给 DispatcherServlet ;
4、 DispatcherServlet 根据获得的 Handler , 通过HandlerAdapter 处理器适配器 选择一个合适的 HandlerAdapter; (附注:如果成功获得 HandlerAdapter 后,此时将开始执行拦截器的 preHandler(...) 方法);
5、 提取 Request 中的模型数据,填充 Handler 入参,开始执行 Handler (即 Controller) ;【 在填充 Handler 的入参过程中,根据你的配置, Spring 将帮你做一些额外的工作如: HttpMessageConveter: 将请求消息(如 Json 、 xml 等数据)转换成一个对象,将对象转换为指定的响应信息 ; 数据转换:对请求消息进行数据转换,如 String 转换成 Integer 、 Double 等;数据格式化: 对请求消息进行数据格式化,如将字符串转换成格式化数字或格式化日期等; 数据验证:验证数据的有效性(长度、格式等),验证结果存储到 BindingResult 或 Error 中 】
6、Controller 执行完成返回ModelAndView 对象;
7、HandlerAdapter 将controller 执行结果ModelAndView 对象返回给DispatcherServlet;
8、DispatcherServlet 将ModelAndView 对象传给ViewReslover 视图解析器;
9、ViewReslover 根据返回的 ModelAndView ,选择一个适合的 ViewResolver (必须是已经注册到 Spring 容器中的 ViewResolver) 返回给 DispatcherServlet;
10、DispatcherServlet 对View 进行渲染视图(即将模型数据填充至视图中);
11、DispatcherServlet 将渲染结果 响应用户(客户端)。
二、SpringMVC 框架设计思路
1、读取配置阶段
图2. SpringMVC 继承关系
第一步就是配置web.xml,加载自定义的DispatcherServlet。而从图中可以看出,SpringMVC 本质上是一个Servlet,这个Servlet 继承自HttpServlet,此外,FrameworkServlet 负责初始SpringMVC的容器,并将Spring 容器设置为父容器;为了读取web.xml 中的配置,需要用到ServletConfig 这个类,它代表当前Servlet 在web.xml 中的配置信息,然后通过web.xml 中加载我们自己写的MyDispatcherServlet 和读取配置文件。
2、初始化阶段
初始化阶段会在DispatcherServlet 类中,按顺序实现下面几个步骤:
1、加载配置文件;
2、扫描当前项目下的所有文件;
3、拿到扫描到的类,通过反射机制将其实例化,并且放到ioc 容器中(Map的键值对 beanName-bean) beanName默认是首字母小写;
4、初始化path 与方法的映射;
5、获取请求传入的参数并处理参数通过初始化好的handlerMapping 中拿出url 对应的方法名,反射调用。
3、运行阶段
运行阶段,每一次请求将会调用doGet 或doPost 方法,它会根据url 请求去HandlerMapping 中匹配到对应的Method,然后利用反射机制调用Controller 中的url 对应的方法,并得到结果返回。
三、实现SpringMVC 框架
首先,小老弟SpringMVC 框架只实现自己的@Controller 和@RequestMapping 注解,其它注解功能实现方式类似,实现注解较少所以项目比较简单,可以看到如下工程文件及目录截图。
图3. 工程文件及目录
1、创建Java Web 工程
创建Java Web 工程,勾选JavaEE 下方的Web Application 选项,Next。
图4. 创建Java Web 工程
2、在工程WEB-INF 下的web.xml 中加入下方配置
1 <?xml version="1.0" encoding="UTF-8"?> 2 <web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee" 3 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" 4 xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_4_0.xsd" 5 version="4.0"> 6 7 <servlet> 8 <servlet-name>DispatcherServlet</servlet-name> 9 <servlet-class>com.tjt.springmvc.DispatcherServlet</servlet-class> 10 </servlet> 11 <servlet-mapping> 12 <servlet-name>DispatcherServlet</servlet-name> 13 <url-pattern>/</url-pattern> 14 </servlet-mapping> 15 16 </web-app>
3、创建自定义Controller 注解
1 package com.tjt.springmvc; 2 3 4 import java.lang.annotation.*; 5 6 7 /** 8 * @MyController 自定义注解类 9 * 10 * @@Target(ElementType.TYPE) 11 * 表示该注解可以作用在类上; 12 * 13 * @Retention(RetentionPolicy.RUNTIME) 14 * 表示该注解会在class 字节码文件中存在,在运行时可以通过反射获取到 15 * 16 * @Documented 17 * 标记注解,表示可以生成文档 18 */ 19 @Target(ElementType.TYPE) 20 @Retention(RetentionPolicy.RUNTIME) 21 @Documented 22 public @interface MyController { 23 24 /** 25 * public class MyController 26 * 把 class 替换成 @interface 该类即成为注解类 27 */ 28 29 /** 30 * 为Controller 注册别名 31 * @return 32 */ 33 String value() default ""; 34 35 }
4、创建自定义RequestMapping 注解
1 package com.tjt.springmvc; 2 3 4 import java.lang.annotation.*; 5 6 7 /** 8 * @MyRequestMapping 自定义注解类 9 * 10 * @Target({ElementType.METHOD,ElementType.TYPE}) 11 * 表示该注解可以作用在方法、类上; 12 * 13 * @Retention(RetentionPolicy.RUNTIME) 14 * 表示该注解会在class 字节码文件中存在,在运行时可以通过反射获取到 15 * 16 * @Documented 17 * 标记注解,表示可以生成文档 18 */ 19 @Target({ElementType.METHOD, ElementType.TYPE}) 20 @Retention(RetentionPolicy.RUNTIME) 21 @Documented 22 public @interface MyRequestMapping { 23 24 /** 25 * public @interface MyRequestMapping 26 * 把 class 替换成 @interface 该类即成为注解类 27 */ 28 29 /** 30 * 表示访问该方法的url 31 * @return 32 */ 33 String value() default ""; 34 35 }
5、设计用于获取项目工程下所有的class 文件的封装工具类
1 package com.tjt.springmvc; 2 3 4 import java.io.File; 5 import java.io.FileFilter; 6 import java.net.JarURLConnection; 7 import java.net.URL; 8 import java.net.URLDecoder; 9 import java.util.ArrayList; 10 import java.util.Enumeration; 11 import java.util.List; 12 import java.util.jar.JarEntry; 13 import java.util.jar.JarFile; 14 15 /** 16 * 从项目工程包package 中获取所有的Class 工具类 17 */ 18 public class ClassUtils { 19 20 /** 21 * 静态常量 22 */ 23 private static String FILE_CONSTANT = "file"; 24 private static String UTF8_CONSTANT = "UTF-8"; 25 private static String JAR_CONSTANT = "jar"; 26 private static String POINT_CLASS_CONSTANT = ".class"; 27 private static char POINT_CONSTANT = '.'; 28 private static char LEFT_LINE_CONSTANT = '/'; 29 30 31 /** 32 * 定义私有构造函数来屏蔽隐式公有构造函数 33 */ 34 private ClassUtils() { 35 } 36 37 38 /** 39 * 从项目工程包package 中获取所有的Class 40 * getClasses 41 * 42 * @param packageName 43 * @return 44 */ 45 public static List<Class<?>> getClasses(String packageName) throws Exception { 46 47 48 List<Class<?>> classes = new ArrayList<Class<?>>(); // 定义一个class 类的泛型集合 49 boolean recursive = true; // recursive 是否循环迭代 50 String packageDirName = packageName.replace(POINT_CONSTANT, LEFT_LINE_CONSTANT); // 获取包的名字 并进行替换 51 Enumeration<URL> dirs; // 定义一个枚举的集合 分别保存该目录下的所有java 类文件及Jar 包等内容 52 dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName); 53 /** 54 * 循环迭代 处理这个目录下的things 55 */ 56 while (dirs.hasMoreElements()) { 57 URL url = dirs.nextElement(); // 获取下一个元素 58 String protocol = url.getProtocol(); // 得到协议的名称 protocol 59 // 如果是 60 /** 61 * 若protocol 是文件形式 62 */ 63 if (FILE_CONSTANT.equals(protocol)) { 64 String filePath = URLDecoder.decode(url.getFile(), UTF8_CONSTANT); // 获取包的物理路径 65 findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes); // 以文件的方式扫描整个包下的文件 并添加到集合中 66 /** 67 * 若protocol 是jar 包文件 68 */ 69 } else if (JAR_CONSTANT.equals(protocol)) { 70 JarFile jar; // 定义一个JarFile 71 jar = ((JarURLConnection) url.openConnection()).getJarFile(); // 获取jar 72 Enumeration<JarEntry> entries = jar.entries(); // 从jar 包中获取枚举类 73 /** 74 * 循环迭代从Jar 包中获得的枚举类 75 */ 76 while (entries.hasMoreElements()) { 77 JarEntry entry = entries.nextElement(); // 获取jar里的一个实体,如目录、META-INF等文件 78 String name = entry.getName(); 79 /** 80 * 若实体名是以 / 开头 81 */ 82 if (name.charAt(0) == LEFT_LINE_CONSTANT) { 83 name = name.substring(1); // 获取后面的字符串 84 } 85 // 如果 86 /** 87 * 若实体名前半部分和定义的包名相同 88 */ 89 if (name.startsWith(packageDirName)) { 90 int idx = name.lastIndexOf(LEFT_LINE_CONSTANT); 91 /** 92 * 并且实体名以为'/' 结尾 93 * 若其以'/' 结尾则是一个包 94 */ 95 if (idx != -1) { 96 packageName = name.substring(0, idx).replace(LEFT_LINE_CONSTANT, POINT_CONSTANT); // 获取包名 并把'/' 替换成'.' 97 } 98 /** 99 * 若实体是一个包 且可以继续迭代 100 */ 101 if ((idx != -1) || recursive) { 102 if (name.endsWith(POINT_CLASS_CONSTANT) && !entry.isDirectory()) { // 若为.class 文件 且不是目录 103 String className = name.substring(packageName.length() + 1, name.length() - 6); // 则去掉.class 后缀并获取真正的类名 104 classes.add(Class.forName(packageName + '.' + className)); // 把获得到的类名添加到classes 105 } 106 } 107 } 108 } 109 } 110 } 111 112 return classes; 113 } 114 115 116 /** 117 * 以文件的形式来获取包下的所有Class 118 * findAndAddClassesInPackageByFile 119 * 120 * @param packageName 121 * @param packagePath 122 * @param recursive 123 * @param classes 124 */ 125 public static void findAndAddClassesInPackageByFile( 126 String packageName, String packagePath, 127 final boolean recursive, 128 List<Class<?>> classes) throws Exception { 129 130 131 File dir = new File(packagePath); // 获取此包的目录并建立一个File 132 133 if (!dir.exists() || !dir.isDirectory()) { // 若dir 不存在或者 也不是目录就直接返回 134 return; 135 } 136 137 File[] dirfiles = dir.listFiles(new FileFilter() { // 若dir 存在 则获取包下的所有文件、目录 138 139 /** 140 * 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class 结尾的文件(编译好的java 字节码文件) 141 * @param file 142 * @return 143 */ 144 @Override 145 public boolean accept(File file) { 146 return (recursive && file.isDirectory()) || (file.getName().endsWith(POINT_CLASS_CONSTANT)); 147 } 148 }); 149 150 /** 151 * 循环所有文件获取java 类文件并添加到集合中 152 */ 153 for (File file : dirfiles) { 154 if (file.isDirectory()) { // 若file 为目录 则继续扫描 155 findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, 156 classes); 157 } else { // 若file 为java 类文件 则去掉后面的.class 只留下类名 158 String className = file.getName().substring(0, file.getName().length() - 6); 159 classes.add(Class.forName(packageName + '.' + className)); // 把className 添加到集合中去 160 161 } 162 } 163 } 164 }
6、访问跳转页面index.jsp
1 <%-- 2 Created by IntelliJ IDEA. 3 User: apple 4 Date: 2019-11-07 5 Time: 13:28 6 To change this template use File | Settings | File Templates. 7 --%> 8 <%-- 9 <%@ page contentType="text/html;charset=UTF-8" language="java" %> 10 --%> 11 <html> 12 <head> 13 <title>My Fucking SpringMVC</title> 14 </head> 15 <body> 16 <h2>The Lie We Live!</h2> 17 <H2>My Fucking SpringMVC</H2> 18 </body> 19 </html>
7、自定义DispatcherServlet 设计,继承HttpServlet,重写init 方法、doGet、doPost 等方法,以及自定义注解要实现的功能。
1 package com.tjt.springmvc; 2 3 4 import javax.servlet.ServletConfig; 5 import javax.servlet.ServletException; 6 import javax.servlet.http.HttpServlet; 7 import javax.servlet.http.HttpServletRequest; 8 import javax.servlet.http.HttpServletResponse; 9 import java.io.IOException; 10 import java.lang.reflect.InvocationTargetException; 11 import java.lang.reflect.Method; 12 import java.util.List; 13 import java.util.Map; 14 import java.util.Objects; 15 import java.util.concurrent.ConcurrentHashMap; 16 17 18 19 /** 20 * DispatcherServlet 处理SpringMVC 框架流程 21 * 主要流程: 22 * 1、包扫描获取包下面所有的类 23 * 2、初始化包下面所有的类 24 * 3、初始化HandlerMapping 方法,将url 和方法对应上 25 * 4、实现HttpServlet 重写doPost 方法 26 * 27 */ 28 public class DispatcherServlet extends HttpServlet { 29 30 /** 31 * 部分静态常量 32 */ 33 private static String PACKAGE_CLASS_NULL_EX = "包扫描后的classes为null"; 34 private static String HTTP_NOT_EXIST = "sorry http is not exit 404"; 35 private static String METHOD_NOT_EXIST = "sorry method is not exit 404"; 36 private static String POINT_JSP = ".jsp"; 37 private static String LEFT_LINE = "/"; 38 39 /** 40 * 用于存放SpringMVC bean 的容器 41 */ 42 private ConcurrentHashMap<String, Object> mvcBeans = new ConcurrentHashMap<>(); 43 private ConcurrentHashMap<String, Object> mvcBeanUrl = new ConcurrentHashMap<>(); 44 private ConcurrentHashMap<String, String> mvcMethodUrl = new ConcurrentHashMap<>(); 45 private static String PROJECT_PACKAGE_PATH = "com.tjt.springmvc"; 46 47 48 /** 49 * 按顺序初始化组件 50 * @param config 51 */ 52 @Override 53 public void init(ServletConfig config) { 54 String packagePath = PROJECT_PACKAGE_PATH; 55 try { 56 //1.进行报扫描获取当前包下面所有的类 57 List<Class<?>> classes = comscanPackage(packagePath); 58 //2.初始化springmvcbean 59 initSpringMvcBean(classes); 60 } catch (Exception e) { 61 e.printStackTrace(); 62 } 63 //3.将请求地址和方法进行映射 64 initHandMapping(mvcBeans); 65 } 66 67 68 /** 69 * 调用ClassUtils 工具类获取工程中所有的class 70 * @param packagePath 71 * @return 72 * @throws Exception 73 */ 74 public List<Class<?>> comscanPackage(String packagePath) throws Exception { 75 List<Class<?>> classes = ClassUtils.getClasses(packagePath); 76 return classes; 77 } 78 79 /** 80 * 初始化SpringMVC bean 81 * 82 * @param classes 83 * @throws Exception 84 */ 85 public void initSpringMvcBean(List<Class<?>> classes) throws Exception { 86 /** 87 * 若包扫描出的classes 为空则直接抛异常 88 */ 89 if (classes.isEmpty()) { 90 throw new Exception(PACKAGE_CLASS_NULL_EX); 91 } 92 93 /** 94 * 遍历所有classes 获取@MyController 注解 95 */ 96 for (Class<?> aClass : classes) { 97 //获取被自定义注解的controller 将其初始化到自定义springmvc 容器中 98 MyController declaredAnnotation = aClass.getDeclaredAnnotation(MyController.class); 99 if (declaredAnnotation != null) { 100 //获取类的名字 101 String beanid = lowerFirstCapse(aClass.getSimpleName()); 102 //获取对象 103 Object beanObj = aClass.newInstance(); 104 //放入spring 容器 105 mvcBeans.put(beanid, beanObj); 106 } 107 } 108 109 } 110 111 /** 112 * 初始化HandlerMapping 方法 113 * 114 * @param mvcBeans 115 */ 116 public void initHandMapping(ConcurrentHashMap<String, Object> mvcBeans) { 117 /** 118 * 遍历springmvc 获取注入的对象值 119 */ 120 for (Map.Entry<String, Object> entry : mvcBeans.entrySet()) { 121 Object objValue = entry.getValue(); 122 Class<?> aClass = objValue.getClass(); 123 //获取当前类 判断其是否有自定义的requestMapping 注解 124 String mappingUrl = null; 125 MyRequestMapping anRequestMapping = aClass.getDeclaredAnnotation(MyRequestMapping.class); 126 if (anRequestMapping != null) { 127 mappingUrl = anRequestMapping.value(); 128 } 129 //获取当前类所有方法,判断方法上是否有注解 130 Method[] declaredMethods = aClass.getDeclaredMethods(); 131 /** 132 * 遍历注解 133 */ 134 for (Method method : declaredMethods) { 135 MyRequestMapping methodDeclaredAnnotation = method.getDeclaredAnnotation(MyRequestMapping.class); 136 if (methodDeclaredAnnotation != null) { 137 String methodUrl = methodDeclaredAnnotation.value(); 138 mvcBeanUrl.put(mappingUrl + methodUrl, objValue); 139 mvcMethodUrl.put(mappingUrl + methodUrl, method.getName()); 140 } 141 } 142 143 } 144 145 } 146 147 /** 148 * @param str 149 * @return 类名首字母小写 150 */ 151 public static String lowerFirstCapse(String str) { 152 char[] chars = str.toCharArray(); 153 chars[0] += 32; 154 return String.valueOf(chars); 155 156 } 157 158 /** 159 * doPost 请求 160 * @param req 161 * @param resp 162 * @throws ServletException 163 * @throws IOException 164 */ 165 @Override 166 protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { 167 try { 168 /** 169 * 处理请求 170 */ 171 doServelt(req, resp); 172 } catch (NoSuchMethodException e) { 173 e.printStackTrace(); 174 } catch (InvocationTargetException e) { 175 e.printStackTrace(); 176 } catch (IllegalAccessException e) { 177 e.printStackTrace(); 178 } 179 } 180 181 /** 182 * doServelt 处理请求 183 * @param req 184 * @param resp 185 * @throws IOException 186 * @throws NoSuchMethodException 187 * @throws InvocationTargetException 188 * @throws IllegalAccessException 189 * @throws ServletException 190 */ 191 private void doServelt(HttpServletRequest req, HttpServletResponse resp) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, ServletException { 192 //获取请求地址 193 String requestUrl = req.getRequestURI(); 194 //查找地址所对应bean 195 Object object = mvcBeanUrl.get(requestUrl); 196 if (Objects.isNull(object)) { 197 resp.getWriter().println(HTTP_NOT_EXIST); 198 return; 199 } 200 //获取请求的方法 201 String methodName = mvcMethodUrl.get(requestUrl); 202 if (methodName == null) { 203 resp.getWriter().println(METHOD_NOT_EXIST); 204 return; 205 } 206 207 208 //通过构反射执行方法 209 Class<?> aClass = object.getClass(); 210 Method method = aClass.getMethod(methodName); 211 212 String invoke = (String) method.invoke(object); 213 // 获取后缀信息 214 String suffix = POINT_JSP; 215 // 页面目录地址 216 String prefix = LEFT_LINE; 217 req.getRequestDispatcher(prefix + invoke + suffix).forward(req, resp); 218 219 220 221 222 } 223 224 /** 225 * doGet 请求 226 * @param req 227 * @param resp 228 * @throws ServletException 229 * @throws IOException 230 */ 231 @Override 232 protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { 233 this.doPost(req, resp); 234 } 235 236 237 }
8、测试手写SpringMVC 框架效果类TestMySpringMVC 。
1 package com.tjt.springmvc; 2 3 4 /** 5 * 手写SpringMVC 测试类 6 * TestMySpringMVC 7 */ 8 @MyController 9 @MyRequestMapping(value = "/tjt") 10 public class TestMySpringMVC { 11 12 13 /** 14 * 测试手写SpringMVC 框架效果 testMyMVC1 15 * @return 16 */ 17 @MyRequestMapping("/mvc") 18 public String testMyMVC1() { 19 System.out.println("he Lie We Live!"); 20 return "index"; 21 } 22 23 24 }
9、配置Tomcat 用于运行Web 项目。
图5. 配置tomcat
10、运行项目,访问测试。
1、输入正常路径 http://localhost:8080/tjt/mvc 访问测试效果如下:
图6. 正常路径测试效果
2、输入非法(不存在)路径 http://localhost:8080/tjt/mvc8 访问测试效果如下:
图7. 非法路径测试效果
3、控制台打印“The Lie We Live”如下:
图8. 控制台打印
测试效果如上则证明成功手写SpringMVC 框架,恭喜。