web
[TOC]
0. 基于 net/http 搭建 HTTP 服务
package main
import (
"net/http"
)
// main registers a handler for /hello on the default ServeMux and serves
// HTTP on port 8888.
func main() {
	// Register the /hello route and bind its handler function.
	http.HandleFunc("/hello", func(writer http.ResponseWriter, request *http.Request) {
		writer.WriteHeader(http.StatusOK)
		writer.Write([]byte("hello world"))
	})
	// Start the HTTP server on port 8888. A nil handler makes net/http use
	// http.DefaultServeMux, which is where http.HandleFunc registered the
	// route above.
	// ListenAndServe only returns on failure (for example when the port is
	// already taken), so report the error instead of exiting silently.
	if err := http.ListenAndServe(":8888", nil); err != nil {
		panic(err)
	}
}
/*
1. 请求
curl http://127.0.0.1:8888/hello
1. 返回
hello world
2. 请求
curl http://127.0.0.1:8888/hell
2. 返回
404 page not found
*/
1. 自定义Handler Mux
package main
import (
"fmt"
"net/http"
)
// Engine is a custom request multiplexer. It implements http.Handler, so it
// can be passed directly to http.ListenAndServe.
type Engine struct{}

// ServeHTTP dispatches on the request path: /hello receives a 200
// "hello world" response; every other path receives a 404 response whose
// body names the requested URL.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/hello":
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("hello world"))
	default:
		// The status has to be written before the body; otherwise the first
		// write implicitly sends 200 OK together with the "404" text.
		w.WriteHeader(http.StatusNotFound)
		fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
	}
}
// main serves HTTP on port 8888 using the custom Engine mux.
func main() {
	// ListenAndServe only returns on failure, so report the error instead of
	// exiting silently.
	if err := http.ListenAndServe(":8888", &Engine{}); err != nil {
		panic(err)
	}
}
/*
1. 请求
curl http://127.0.0.1:8888/hello
1. 返回
hello world
2. 请求
curl http://127.0.0.1:8888/hell
2. 返回
404 NOT FOUND: /hell
*/
2. 静态路由
package main
import (
"fmt"
"net/http"
)
// main registers GET /hello on a new Engine and serves it on port 8888.
func main() {
	e := New()
	e.Get("/hello", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("hello world"))
	})
	// Run only returns when the server fails to start or stops; report it.
	if err := e.Run(":8888"); err != nil {
		panic(err)
	}
}
// Method is an HTTP request method name; it forms the first half of a route key.
type Method string

const (
	// MethodGET is the GET method.
	MethodGET Method = http.MethodGet
	// MethodPOST is the POST method.
	MethodPOST Method = http.MethodPost
)

// Engine is the framework entry point. It embeds router, so the route
// registration methods defined on router are callable on an Engine.
type Engine struct {
	router
}

// HandlerFunc is the signature of a route handler.
type HandlerFunc func(http.ResponseWriter, *http.Request)
// New returns an Engine whose route table is empty and ready for registration.
func New() *Engine {
	routes := make(router)
	return &Engine{router: routes}
}
// ServeHTTP looks up the handler registered for the request's method and
// path and invokes it. Unregistered routes receive a 404 response whose body
// names the requested URL.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	key := genRouterKey(Method(r.Method), r.URL.Path)
	if handle, ok := e.router[key]; ok {
		handle(w, r)
		return
	}
	// Write the status first; otherwise the body write implicitly sends 200 OK.
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
}
// Run serves HTTP on addr with the Engine as the handler and returns the
// error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	var handler http.Handler = e
	return http.ListenAndServe(addr, handler)
}
// router maps a route key (see genRouterKey) to the handler registered for it.
type router map[string]HandlerFunc
// addRouter stores handle under the key derived from method and path,
// replacing any handler previously stored under that key.
func (r router) addRouter(method Method, path string, handle HandlerFunc) {
	key := genRouterKey(method, path)
	r[key] = handle
}
// genRouterKey builds the route-table key "<method>-<path>".
func genRouterKey(method Method, path string) string {
	const sep = "-"
	key := string(method) + sep + path
	return key
}
// Get registers handle for GET requests to path.
func (r router) Get(path string, handle HandlerFunc) {
	const method = MethodGET
	r.addRouter(method, path, handle)
}
// Post registers handle for POST requests to path.
func (r router) Post(path string, handle HandlerFunc) {
	const method = MethodPOST
	r.addRouter(method, path, handle)
}
/*
1. 请求
curl http://127.0.0.1:8888/hello
1. 返回
hello world
2. 请求
curl http://127.0.0.1:8888/hell
2. 返回
404 NOT FOUND: /hell
*/
3. 上下文 快速构造HTTP响应
package main
import (
"encoding/json"
"fmt"
"net/http"
)
// main registers GET /hello on a new Engine and serves it on port 8888.
func main() {
	e := New()
	e.Get("/hello", func(c *Context) {
		c.String(http.StatusOK, "hello world")
	})
	// Run only returns when the server fails to start or stops; report it.
	if err := e.Run(":8888"); err != nil {
		panic(err)
	}
}
// Context wraps the request and response writer of a single HTTP exchange
// and offers helpers for reading input and writing responses.
type Context struct {
	// Original request and response writer.
	r *http.Request
	w http.ResponseWriter
}
// NewContext builds a Context around the given response writer and request.
func NewContext(w http.ResponseWriter, req *http.Request) *Context {
	c := new(Context)
	c.r = req
	c.w = w
	return c
}
// PostForm returns the value of key as reported by http.Request.FormValue.
func (c *Context) PostForm(key string) string {
	value := c.r.FormValue(key)
	return value
}
// Query returns the first value of key in the URL query string, or "" if absent.
func (c *Context) Query(key string) string {
	values := c.r.URL.Query()
	return values.Get(key)
}
// Status writes the response header with the given status code.
func (c *Context) Status(code int) {
	writer := c.w
	writer.WriteHeader(code)
}
// SetHeader sets the response header key to value, replacing existing values.
func (c *Context) SetHeader(key string, value string) {
	headers := c.w.Header()
	headers.Set(key, value)
}
// String writes a text/plain response with the given status code. The body
// is format expanded with values, as with fmt.Sprintf.
func (c *Context) String(code int, format string, values ...interface{}) {
	c.w.Header().Set("Content-Type", "text/plain")
	c.w.WriteHeader(code)
	fmt.Fprintf(c.w, format, values...)
}
// JSON writes obj as an application/json response with the given status code.
//
// obj is encoded before anything is sent: once the status line has been
// written it cannot be changed, so an encoding failure must be detected
// first in order to answer with 500 Internal Server Error.
func (c *Context) JSON(code int, obj interface{}) {
	body, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.w, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	// json.Encoder.Encode ends every value with a newline; keep that framing.
	c.w.Write(append(body, '\n'))
}
// Data writes the status code followed by the raw bytes in data.
func (c *Context) Data(code int, data []byte) {
	out := c.w
	out.WriteHeader(code)
	out.Write(data)
}
// HTML writes html as a text/html response with the given status code.
func (c *Context) HTML(code int, html string) {
	c.w.Header().Set("Content-Type", "text/html")
	c.w.WriteHeader(code)
	body := []byte(html)
	c.w.Write(body)
}
// Method is an HTTP request method name; it forms the first half of a route key.
type Method string

const (
	// MethodGET is the GET method.
	MethodGET Method = http.MethodGet
	// MethodPOST is the POST method.
	MethodPOST Method = http.MethodPost
)

// Engine is the framework entry point. It embeds router, so the route
// registration methods defined on router are callable on an Engine.
type Engine struct {
	router
}

// HandlerFunc is the signature of a route handler; it receives the
// per-request Context.
type HandlerFunc func(c *Context)
// New returns an Engine whose route table is empty and ready for registration.
func New() *Engine {
	routes := make(router)
	return &Engine{router: routes}
}
// ServeHTTP looks up the handler registered for the request's method and
// path and invokes it with a fresh Context. Unregistered routes receive a
// 404 response whose body names the requested URL.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	key := genRouterKey(Method(r.Method), r.URL.Path)
	if handle, ok := e.router[key]; ok {
		handle(NewContext(w, r))
		return
	}
	// Write the status first; otherwise the body write implicitly sends 200 OK.
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
}
// Run serves HTTP on addr with the Engine as the handler and returns the
// error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	var handler http.Handler = e
	return http.ListenAndServe(addr, handler)
}
// router maps a route key (see genRouterKey) to the handler registered for it.
type router map[string]HandlerFunc
// addRouter stores handle under the key derived from method and path,
// replacing any handler previously stored under that key.
func (r router) addRouter(method Method, path string, handle HandlerFunc) {
	key := genRouterKey(method, path)
	r[key] = handle
}
// genRouterKey builds the route-table key "<method>-<path>".
func genRouterKey(method Method, path string) string {
	const sep = "-"
	key := string(method) + sep + path
	return key
}
// Get registers handle for GET requests to path.
func (r router) Get(path string, handle HandlerFunc) {
	const method = MethodGET
	r.addRouter(method, path, handle)
}
// Post registers handle for POST requests to path.
func (r router) Post(path string, handle HandlerFunc) {
	const method = MethodPOST
	r.addRouter(method, path, handle)
}
4. 动态路由
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
)
// main registers a static route, two parameterised routes and a catch-all
// route, then serves them on port 8888.
func main() {
	e := New()
	e.Get("/hello", func(c *Context) {
		c.String(http.StatusOK, "hello world")
	})
	e.Get("/hi/:name", func(c *Context) {
		c.String(http.StatusOK, "hi, %s", c.params["name"])
	})
	e.Get("/hi/:name/info/:xx/", func(c *Context) {
		c.String(http.StatusOK, "hi, %s, xx: %s", c.params["name"], c.params["xx"])
	})
	e.Get("/file/*filename", func(c *Context) {
		c.String(http.StatusOK, "file: %s", c.params["filename"])
	})
	// Run only returns when the server fails to start or stops; report it.
	if err := e.Run(":8888"); err != nil {
		log.Fatal(err)
	}
}
/* Context */

// Context wraps the request and response writer of a single HTTP exchange
// together with the route parameters captured for it.
type Context struct {
	// Original request and response writer.
	r *http.Request
	w http.ResponseWriter
	// params holds the values captured from wildcard route segments,
	// keyed by parameter name.
	params map[string]string
}
// NewContext builds a Context around the given response writer and request.
// params is left nil for the caller to fill in.
func NewContext(w http.ResponseWriter, req *http.Request) *Context {
	c := new(Context)
	c.r = req
	c.w = w
	return c
}
// PostForm returns the value of key as reported by http.Request.FormValue.
func (c *Context) PostForm(key string) string {
	value := c.r.FormValue(key)
	return value
}
// Query returns the first value of key in the URL query string, or "" if absent.
func (c *Context) Query(key string) string {
	values := c.r.URL.Query()
	return values.Get(key)
}
// Status writes the response header with the given status code.
func (c *Context) Status(code int) {
	writer := c.w
	writer.WriteHeader(code)
}
// SetHeader sets the response header key to value, replacing existing values.
func (c *Context) SetHeader(key string, value string) {
	headers := c.w.Header()
	headers.Set(key, value)
}
// String writes a text/plain response with the given status code. The body
// is format expanded with values, as with fmt.Sprintf.
func (c *Context) String(code int, format string, values ...interface{}) {
	c.w.Header().Set("Content-Type", "text/plain")
	c.w.WriteHeader(code)
	fmt.Fprintf(c.w, format, values...)
}
// JSON writes obj as an application/json response with the given status code.
//
// obj is encoded before anything is sent: once the status line has been
// written it cannot be changed, so an encoding failure must be detected
// first in order to answer with 500 Internal Server Error.
func (c *Context) JSON(code int, obj interface{}) {
	body, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.w, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	// json.Encoder.Encode ends every value with a newline; keep that framing.
	c.w.Write(append(body, '\n'))
}
// Data writes the status code followed by the raw bytes in data.
func (c *Context) Data(code int, data []byte) {
	out := c.w
	out.WriteHeader(code)
	out.Write(data)
}
// HTML writes html as a text/html response with the given status code.
func (c *Context) HTML(code int, html string) {
	c.w.Header().Set("Content-Type", "text/html")
	c.w.WriteHeader(code)
	body := []byte(html)
	c.w.Write(body)
}
/*Engine*/

// Method is an HTTP request method name; it selects the routing trie and
// forms the first half of a route key.
type Method string

const (
	// MethodGET is the GET method.
	MethodGET Method = http.MethodGet
	// MethodPOST is the POST method.
	MethodPOST Method = http.MethodPost
)

// Engine is the framework entry point. It embeds router, so the route
// registration methods defined on router are callable on an Engine.
type Engine struct {
	router
}

// HandlerFunc is the signature of a route handler; it receives the
// per-request Context.
type HandlerFunc func(c *Context)
// New returns an Engine with empty per-method tries and an empty handler table.
func New() *Engine {
	rt := router{
		roots:     make(map[Method]*node),
		handleMap: make(map[string]HandlerFunc),
	}
	return &Engine{router: rt}
}
// ServeHTTP resolves the request path against the trie of the request's
// method and, when a registered pattern matches, invokes its handler with a
// Context carrying the captured parameters.
//
// Every request that does not resolve to a handler receives a 404 response
// whose body names the requested URL. That covers both a method with no
// routes at all and a method whose trie has no matching route.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	method := Method(r.Method)
	if root, ok := e.roots[method]; ok {
		pattern, params := root.search(r.URL.Path)
		if handler, ok := e.handleMap[genRouterKey(method, pattern)]; ok {
			ctx := NewContext(w, r)
			ctx.params = params
			handler(ctx)
			return
		}
	}
	// Write the status first; otherwise the body write implicitly sends 200 OK.
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
}
// Run serves HTTP on addr with the Engine as the handler and returns the
// error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	var handler http.Handler = e
	return http.ListenAndServe(addr, handler)
}
/* router */

// router holds the per-method routing tries and the handler table.
type router struct {
	// roots maps an HTTP method to the root node of that method's trie.
	roots map[Method]*node
	// handleMap maps a route key (see genRouterKey) to its handler.
	handleMap map[string]HandlerFunc
}
// addRouter inserts path into the trie for method, creating the trie on
// first use, and binds handle to the route key for (method, path).
func (r router) addRouter(method Method, path string, handle HandlerFunc) {
	root, ok := r.roots[method]
	if !ok {
		root = newTrie()
		r.roots[method] = root
	}
	root.insert(path)
	// Bind the handler to the route key.
	key := genRouterKey(method, path)
	r.handleMap[key] = handle
}
// genRouterKey builds the handler-table key "<method>-<path>".
func genRouterKey(method Method, path string) string {
	const sep = "-"
	key := string(method) + sep + path
	return key
}
// Get registers handle for GET requests matching path.
func (r router) Get(path string, handle HandlerFunc) {
	const method = MethodGET
	r.addRouter(method, path, handle)
}
// Post registers handle for POST requests matching path.
func (r router) Post(path string, handle HandlerFunc) {
	const method = MethodPOST
	r.addRouter(method, path, handle)
}
/* trie */
const (
	// slash is the separator used to split paths and to re-join captured parts.
	slash = "/"
	// asteriskByte marks a catch-all segment such as "*filename".
	asteriskByte byte = '*'
	// colonByte marks a single-segment parameter such as ":name".
	colonByte byte = ':'
	// codeMe is the starting value of the node code counter.
	// NOTE(review): the value looks like a date (2021-06-07) and is presumably arbitrary — confirm.
	codeMe = 20210607
	// codeNotFound is returned by match when no route can be followed.
	codeNotFound = -404
)

var (
	// code is the most recently issued node code; codeIncr advances it.
	code int32 = codeMe
	// pattenMap maps the code of a pattern's final trie node to the full
	// pattern string as it was registered.
	pattenMap = make(map[int32]string)
)
// codeIncr advances the package-level code counter by one and returns the
// new value.
func codeIncr() int32 {
	code = code + 1
	return code
}
// getCode returns the current value of the package-level code counter
// without changing it.
func getCode() int32 {
	current := code
	return current
}
// node is one path segment in the routing trie.
type node struct {
	// code identifies the node; insert assigns it from codeIncr, while a
	// root created by newTrie keeps the zero value.
	code int32
	// wild is ':' or '*' for a wildcard segment and zero for a literal one.
	wild byte
	// children maps a segment's text as registered (including any ':' or
	// '*' marker) to its node.
	children map[string]*node
	// param is the wildcard's parameter name (the segment without its marker).
	param string
}
// newTrie returns an empty root node with an initialised children map.
func newTrie() *node {
	root := new(node)
	root.children = make(map[string]*node)
	return root
}
// search resolves path against the trie rooted at n. It returns the
// registered pattern recorded for the matched node ("" when none is
// recorded) and the wildcard values captured along the way.
func (n *node) search(path string) (string, map[string]string) {
	segments := parsePattern(path)
	captured := map[string]string{}
	// Start at the root (n) with the first segment.
	matched := n.match(n, segments, 0, captured)
	log.Printf("pattenMap: %+v", pattenMap)
	log.Printf("code: %d", matched)
	pattern := pattenMap[matched]
	return pattern, captured
}
// match recursively walks the trie to resolve parts[idx:].
//
// curNode is the node reached after consuming parts[:idx]; the receiver is
// only used to make the recursive calls. Values captured by wildcard nodes
// are written into params. It returns the code of the node where matching
// stopped, or codeNotFound when neither a literal nor a wildcard child can
// consume parts[idx]. The returned code is not checked against pattenMap
// here, so it may belong to a node with no registered pattern.
func (n *node) match(curNode *node, parts []string, idx int, params map[string]string) int32 {
	if idx == len(parts) {
		// Every part has been consumed.
		if curNode == nil {
			log.Printf("curNode is nil")
			// No rule matched.
			return codeNotFound
		}
		// If the final node is a wildcard, capture the last part as its value.
		if curNode.isWildChild(curNode.wild) {
			params[curNode.param] = parts[idx-1]
		}
		return curNode.code
	}
	log.Printf("parts[i]: %s", parts[idx])
	if curNode.wild == asteriskByte {
		// NOTE(review): this logs the value currently stored in params,
		// not the parameter name — confirm which one was intended.
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// curNode was entered by consuming parts[idx-1], so a '*' node
		// captures that part and everything after it.
		params[curNode.param] = strings.Join(parts[idx-1:], slash)
		// Matched a '*' catch-all rule; stop descending.
		return curNode.code
	}
	if curNode.wild == colonByte {
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// A ':' node captures exactly the part that led to it.
		params[curNode.param] = parts[idx-1]
	}
	// Prefer an exact literal match on the next part.
	if _, has := curNode.children[parts[idx]]; !has {
		// No literal child: fall back to a wildcard child, if there is one.
		var tNode = curNode.getNode()
		if tNode == nil {
			return codeNotFound
		}
		return curNode.match(tNode, parts, idx+1, params)
	}
	return n.match(curNode.children[parts[idx]], parts, idx+1, params)
}
// getNode returns the wildcard child of n with the smallest code, or nil
// when n has no wildcard child.
func (n *node) getNode() *node {
	var best *node
	// Start above every issued code so that any wildcard child qualifies.
	limit := getCode() + 1
	for _, child := range n.children {
		if n.isWildChild(child.wild) && child.code < limit {
			best = child
			limit = child.code
		}
	}
	return best
}
// insert adds pattern to the trie rooted at n, creating one node per path
// segment as needed, and records pattern in pattenMap under the code of the
// final node.
func (n *node) insert(pattern string) {
	parts := parsePattern(pattern)
	log.Printf("pattern: %s, parts: %v", pattern, parts)
	cur := n
	for _, part := range parts {
		log.Printf("parts[i]: %s", part)
		child, ok := cur.children[part]
		if !ok {
			child = &node{children: map[string]*node{}, code: codeIncr()}
			cur.children[part] = child
		}
		cur = child
		// A leading ':' or '*' makes the segment a wildcard; the rest of
		// the segment is the parameter name.
		if marker := part[0]; n.isWildChild(marker) {
			cur.wild = marker
			cur.param = part[1:]
			log.Printf("pattern: %s set param: %s", pattern, cur.param)
		}
	}
	// Bind the final node's code to the pattern.
	pattenMap[cur.code] = pattern
}
// isWildChild reports whether b is one of the wildcard markers ':' or '*'.
func (n *node) isWildChild(b byte) bool {
	switch b {
	case colonByte, asteriskByte:
		return true
	default:
		return false
	}
}
// parsePattern splits pattern on "/" into trimmed, non-empty segments.
// Splitting stops after the first segment that starts with '*'.
func parsePattern(pattern string) []string {
	var parts []string
	for _, raw := range strings.Split(pattern, slash) {
		segment := strings.TrimSpace(raw)
		if len(segment) == 0 {
			continue
		}
		parts = append(parts, segment)
		if segment[0] == asteriskByte {
			// e.g. /static/*filepath matches both /static/fav.ico and
			// /static/js/jQuery.js. This form is common for static file
			// servers because it matches sub-paths recursively.
			break
		}
	}
	return parts
}
5. 分组控制
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
)
// RouterGroup is a set of routes sharing a path prefix.
// One Engine can have multiple groups registered on it.
type RouterGroup struct {
	// prefix is the group name; it is prepended to every route path.
	prefix string
	// engine is the Engine the group registers its routes on.
	engine *Engine
}
// Group returns a new RouterGroup with the given prefix that registers its
// routes on e.
func (e *Engine) Group(prefix string) *RouterGroup {
	group := new(RouterGroup)
	group.prefix = prefix
	group.engine = e
	return group
}
// main registers routes under the "v1" group and directly on the Engine,
// then serves them on port 8888.
func main() {
	e := New()
	v1 := e.Group("v1")
	{
		v1.Get("/hello", func(c *Context) {
			c.String(http.StatusOK, "%s hello world", v1.GetPrefix())
		})
		v1.Get("/hi/:name", func(c *Context) {
			c.String(http.StatusOK, "%s hi, %s", v1.GetPrefix(), c.params["name"])
		})
		v1.Get("/hi/:name/info/:xx/", func(c *Context) {
			c.String(http.StatusOK, "%s hi, %s, xx: %s", v1.GetPrefix(), c.params["name"], c.params["xx"])
		})
		v1.Get("/file/*filename", func(c *Context) {
			c.String(http.StatusOK, "%s file: %s", v1.GetPrefix(), c.params["filename"])
		})
	}
	e.Get("/hi/:name/info/:xx/", func(c *Context) {
		c.String(http.StatusOK, "hi, %s, xx: %s", c.params["name"], c.params["xx"])
	})
	e.Get("/file/*filename", func(c *Context) {
		c.String(http.StatusOK, "file: %s", c.params["filename"])
	})
	// Run only returns when the server fails to start or stops; report it.
	if err := e.Run(":8888"); err != nil {
		log.Fatal(err)
	}
}
/* Context */

// Context wraps the request and response writer of a single HTTP exchange
// together with the route parameters captured for it.
type Context struct {
	// Original request and response writer.
	r *http.Request
	w http.ResponseWriter
	// params holds the values captured from wildcard route segments,
	// keyed by parameter name.
	params map[string]string
}
// NewContext builds a Context around the given response writer and request.
// params is left nil for the caller to fill in.
func NewContext(w http.ResponseWriter, req *http.Request) *Context {
	c := new(Context)
	c.r = req
	c.w = w
	return c
}
// PostForm returns the value of key as reported by http.Request.FormValue.
func (c *Context) PostForm(key string) string {
	value := c.r.FormValue(key)
	return value
}
// Query returns the first value of key in the URL query string, or "" if absent.
func (c *Context) Query(key string) string {
	values := c.r.URL.Query()
	return values.Get(key)
}
// Status writes the response header with the given status code.
func (c *Context) Status(code int) {
	writer := c.w
	writer.WriteHeader(code)
}
// SetHeader sets the response header key to value, replacing existing values.
func (c *Context) SetHeader(key string, value string) {
	headers := c.w.Header()
	headers.Set(key, value)
}
// String writes a text/plain response with the given status code. The body
// is format expanded with values, as with fmt.Sprintf.
func (c *Context) String(code int, format string, values ...interface{}) {
	c.w.Header().Set("Content-Type", "text/plain")
	c.w.WriteHeader(code)
	fmt.Fprintf(c.w, format, values...)
}
// JSON writes obj as an application/json response with the given status code.
//
// obj is encoded before anything is sent: once the status line has been
// written it cannot be changed, so an encoding failure must be detected
// first in order to answer with 500 Internal Server Error.
func (c *Context) JSON(code int, obj interface{}) {
	body, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.w, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	// json.Encoder.Encode ends every value with a newline; keep that framing.
	c.w.Write(append(body, '\n'))
}
// Data writes the status code followed by the raw bytes in data.
func (c *Context) Data(code int, data []byte) {
	out := c.w
	out.WriteHeader(code)
	out.Write(data)
}
// HTML writes html as a text/html response with the given status code.
func (c *Context) HTML(code int, html string) {
	c.w.Header().Set("Content-Type", "text/html")
	c.w.WriteHeader(code)
	body := []byte(html)
	c.w.Write(body)
}
/*Engine*/

// Method is an HTTP request method name; it selects the routing trie and
// forms the first half of a route key.
type Method string

const (
	// MethodGET is the GET method.
	MethodGET Method = http.MethodGet
	// MethodPOST is the POST method.
	MethodPOST Method = http.MethodPost
)

// Engine is the framework entry point. It embeds router for route storage
// and RouterGroup so routes can be registered on the Engine itself.
type Engine struct {
	router
	RouterGroup
}

// HandlerFunc is the signature of a route handler; it receives the
// per-request Context.
type HandlerFunc func(c *Context)
// New returns an Engine with empty route tables whose embedded RouterGroup
// has an empty prefix and points back at the Engine itself.
func New() *Engine {
	e := new(Engine)
	e.router = router{
		roots:     make(map[Method]*node),
		handleMap: make(map[string]HandlerFunc),
	}
	e.RouterGroup = RouterGroup{prefix: "", engine: e}
	return e
}
// ServeHTTP resolves the request path against the trie of the request's
// method and, when a registered pattern matches, invokes its handler with a
// Context carrying the captured parameters.
//
// Every request that does not resolve to a handler receives a 404 response
// whose body names the requested URL. That covers both a method with no
// routes at all and a method whose trie has no matching route.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	method := Method(r.Method)
	if root, ok := e.roots[method]; ok {
		pattern, params := root.search(r.URL.Path)
		if handler, ok := e.handleMap[genRouterKey(method, pattern)]; ok {
			ctx := NewContext(w, r)
			ctx.params = params
			handler(ctx)
			return
		}
	}
	// Write the status first; otherwise the body write implicitly sends 200 OK.
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
}
// Run serves HTTP on addr with the Engine as the handler and returns the
// error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	var handler http.Handler = e
	return http.ListenAndServe(addr, handler)
}
/* router */

// router holds the per-method routing tries and the handler table.
type router struct {
	// roots maps an HTTP method to the root node of that method's trie.
	roots map[Method]*node
	// handleMap maps a route key (see genRouterKey) to its handler.
	handleMap map[string]HandlerFunc
}
// addRouter inserts path into the trie for method, creating the trie on
// first use, and binds handle to the route key for (method, path).
func (r router) addRouter(method Method, path string, handle HandlerFunc) {
	root, ok := r.roots[method]
	if !ok {
		root = newTrie()
		r.roots[method] = root
	}
	root.insert(path)
	// Bind the handler to the route key.
	key := genRouterKey(method, path)
	r.handleMap[key] = handle
}
// genRouterKey builds the handler-table key "<method>-<path>".
func genRouterKey(method Method, path string) string {
	const sep = "-"
	key := string(method) + sep + path
	return key
}
// Get registers handle for GET requests matching the group prefix followed
// by path.
func (r *RouterGroup) Get(path string, handle HandlerFunc) {
	pattern := r.prefix + path
	r.engine.addRouter(MethodGET, pattern, handle)
}
// Post registers handle for POST requests matching the group prefix
// followed by path.
func (r *RouterGroup) Post(path string, handle HandlerFunc) {
	pattern := r.prefix + path
	r.engine.addRouter(MethodPOST, pattern, handle)
}
// GetPrefix returns the group's path prefix.
func (r *RouterGroup) GetPrefix() string {
	prefix := r.prefix
	return prefix
}
/* trie */
const (
	// slash is the separator used to split paths and to re-join captured parts.
	slash = "/"
	// asteriskByte marks a catch-all segment such as "*filename".
	asteriskByte byte = '*'
	// colonByte marks a single-segment parameter such as ":name".
	colonByte byte = ':'
	// codeMe is the starting value of the node code counter.
	// NOTE(review): the value looks like a date (2021-06-07) and is presumably arbitrary — confirm.
	codeMe = 20210607
	// codeNotFound is returned by match when no route can be followed.
	codeNotFound = -404
)

var (
	// code is the most recently issued node code; codeIncr advances it.
	code int32 = codeMe
	// pattenMap maps the code of a pattern's final trie node to the full
	// pattern string as it was registered.
	pattenMap = make(map[int32]string)
)
// codeIncr advances the package-level code counter by one and returns the
// new value.
func codeIncr() int32 {
	code = code + 1
	return code
}
// getCode returns the current value of the package-level code counter
// without changing it.
func getCode() int32 {
	current := code
	return current
}
// node is one path segment in the routing trie.
type node struct {
	// code identifies the node; insert assigns it from codeIncr, while a
	// root created by newTrie keeps the zero value.
	code int32
	// wild is ':' or '*' for a wildcard segment and zero for a literal one.
	wild byte
	// children maps a segment's text as registered (including any ':' or
	// '*' marker) to its node.
	children map[string]*node
	// param is the wildcard's parameter name (the segment without its marker).
	param string
}
// newTrie returns an empty root node with an initialised children map.
func newTrie() *node {
	root := new(node)
	root.children = make(map[string]*node)
	return root
}
// search resolves path against the trie rooted at n. It returns the
// registered pattern recorded for the matched node ("" when none is
// recorded) and the wildcard values captured along the way.
func (n *node) search(path string) (string, map[string]string) {
	segments := parsePattern(path)
	captured := map[string]string{}
	// Start at the root (n) with the first segment.
	matched := n.match(n, segments, 0, captured)
	log.Printf("pattenMap: %+v", pattenMap)
	log.Printf("code: %d", matched)
	pattern := pattenMap[matched]
	return pattern, captured
}
// match recursively walks the trie to resolve parts[idx:].
//
// curNode is the node reached after consuming parts[:idx]; the receiver is
// only used to make the recursive calls. Values captured by wildcard nodes
// are written into params. It returns the code of the node where matching
// stopped, or codeNotFound when neither a literal nor a wildcard child can
// consume parts[idx]. The returned code is not checked against pattenMap
// here, so it may belong to a node with no registered pattern.
func (n *node) match(curNode *node, parts []string, idx int, params map[string]string) int32 {
	if idx == len(parts) {
		// Every part has been consumed.
		if curNode == nil {
			log.Printf("curNode is nil")
			// No rule matched.
			return codeNotFound
		}
		// If the final node is a wildcard, capture the last part as its value.
		if curNode.isWildChild(curNode.wild) {
			params[curNode.param] = parts[idx-1]
		}
		return curNode.code
	}
	log.Printf("parts[i]: %s", parts[idx])
	if curNode.wild == asteriskByte {
		// NOTE(review): this logs the value currently stored in params,
		// not the parameter name — confirm which one was intended.
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// curNode was entered by consuming parts[idx-1], so a '*' node
		// captures that part and everything after it.
		params[curNode.param] = strings.Join(parts[idx-1:], slash)
		// Matched a '*' catch-all rule; stop descending.
		return curNode.code
	}
	if curNode.wild == colonByte {
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// A ':' node captures exactly the part that led to it.
		params[curNode.param] = parts[idx-1]
	}
	// Prefer an exact literal match on the next part.
	if _, has := curNode.children[parts[idx]]; !has {
		// No literal child: fall back to a wildcard child, if there is one.
		var tNode = curNode.getNode()
		if tNode == nil {
			return codeNotFound
		}
		return curNode.match(tNode, parts, idx+1, params)
	}
	return n.match(curNode.children[parts[idx]], parts, idx+1, params)
}
// getNode returns the wildcard child of n with the smallest code, or nil
// when n has no wildcard child.
func (n *node) getNode() *node {
	var best *node
	// Start above every issued code so that any wildcard child qualifies.
	limit := getCode() + 1
	for _, child := range n.children {
		if n.isWildChild(child.wild) && child.code < limit {
			best = child
			limit = child.code
		}
	}
	return best
}
// insert adds pattern to the trie rooted at n, creating one node per path
// segment as needed, and records pattern in pattenMap under the code of the
// final node.
func (n *node) insert(pattern string) {
	parts := parsePattern(pattern)
	log.Printf("pattern: %s, parts: %v", pattern, parts)
	cur := n
	for _, part := range parts {
		log.Printf("parts[i]: %s", part)
		child, ok := cur.children[part]
		if !ok {
			child = &node{children: map[string]*node{}, code: codeIncr()}
			cur.children[part] = child
		}
		cur = child
		// A leading ':' or '*' makes the segment a wildcard; the rest of
		// the segment is the parameter name.
		if marker := part[0]; n.isWildChild(marker) {
			cur.wild = marker
			cur.param = part[1:]
			log.Printf("pattern: %s set param: %s", pattern, cur.param)
		}
	}
	// Bind the final node's code to the pattern.
	pattenMap[cur.code] = pattern
}
// isWildChild reports whether b is one of the wildcard markers ':' or '*'.
func (n *node) isWildChild(b byte) bool {
	switch b {
	case colonByte, asteriskByte:
		return true
	default:
		return false
	}
}
// parsePattern splits pattern on "/" into trimmed, non-empty segments.
// Splitting stops after the first segment that starts with '*'.
func parsePattern(pattern string) []string {
	var parts []string
	for _, raw := range strings.Split(pattern, slash) {
		segment := strings.TrimSpace(raw)
		if len(segment) == 0 {
			continue
		}
		parts = append(parts, segment)
		if segment[0] == asteriskByte {
			// e.g. /static/*filepath matches both /static/fav.ico and
			// /static/js/jQuery.js. This form is common for static file
			// servers because it matches sub-paths recursively.
			break
		}
	}
	return parts
}
6. 中间件
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
)
//http://127.0.0.1:8888/v1/file/xxx.css
//2021/06/12 15:11:20 测试 中间件demo1
//2021/06/12 15:11:20 测试 中间件demo2
//2021/06/12 15:11:20 file: xxx.css
//2021/06/12 15:11:20 测试 demo2中间件执行完毕
//2021/06/12 15:11:20 测试 demo1中间件执行完毕
// demo returns a sample middleware that logs a line before and after the
// rest of the handler chain runs.
func demo() HandlerFunc {
	middleware := func(c *Context) {
		log.Println("测试 中间件demo1")
		c.Next()
		log.Println("测试 demo1中间件执行完毕")
	}
	return middleware
}
// demo2 returns a second sample middleware that logs a line before and
// after the rest of the handler chain runs.
func demo2() HandlerFunc {
	middleware := func(c *Context) {
		log.Println("测试 中间件demo2")
		c.Next()
		log.Println("测试 demo2中间件执行完毕")
	}
	return middleware
}
// main registers routes under the "v1" group (with the demo middlewares
// attached) and directly on the Engine, then serves them on port 8888.
func main() {
	e := New()
	v1 := e.Group("v1").Use(demo(), demo2())
	{
		v1.Get("/hello", func(c *Context) {
			c.String(http.StatusOK, "%s hello world", v1.GetPrefix())
		})
		v1.Get("/hi/:name", func(c *Context) {
			c.String(http.StatusOK, "%s hi, %s", v1.GetPrefix(), c.params["name"])
		})
		v1.Get("/hi/:name/info/:xx/", func(c *Context) {
			c.String(http.StatusOK, "%s hi, %s, xx: %s", v1.GetPrefix(), c.params["name"], c.params["xx"])
		})
		v1.Get("/file/*filename", func(c *Context) {
			log.Printf("file: %s", c.params["filename"])
			c.String(http.StatusOK, "%s file: %s", v1.GetPrefix(), c.params["filename"])
		})
	}
	e.Get("/hi/:name/info/:xx/", func(c *Context) {
		c.String(http.StatusOK, "hi, %s, xx: %s", c.params["name"], c.params["xx"])
	})
	e.Get("/file/*filename", func(c *Context) {
		c.String(http.StatusOK, "file: %s", c.params["filename"])
	})
	// Run only returns when the server fails to start or stops; report it.
	if err := e.Run(":8888"); err != nil {
		log.Fatal(err)
	}
}
// prefixMiddlewares maps a group prefix to the middlewares registered for
// it through RouterGroup.Use.
var prefixMiddlewares = make(map[string][]HandlerFunc)

// RouterGroup is a set of routes sharing a path prefix.
// One Engine can have multiple groups registered on it.
type RouterGroup struct {
	// prefix is the group name; it is prepended to every route path.
	prefix string
	// engine is the Engine the group registers its routes on.
	engine *Engine
}
// Use appends fs to the middlewares registered for the group's prefix and
// returns the group so calls can be chained.
//
// The handlers are always copied into the registry (append onto a nil slice
// allocates), so a caller that later modifies the slice it passed in cannot
// alter the registered chain.
func (r *RouterGroup) Use(fs ...HandlerFunc) *RouterGroup {
	prefix := r.GetPrefix()
	prefixMiddlewares[prefix] = append(prefixMiddlewares[prefix], fs...)
	return r
}
// Group returns a new RouterGroup with the given prefix that registers its
// routes on e.
func (e *Engine) Group(prefix string) *RouterGroup {
	group := new(RouterGroup)
	group.prefix = prefix
	group.engine = e
	return group
}
/* Context */

// Context carries the state of a single request through the handler chain.
type Context struct {
	// Original request and response writer.
	r *http.Request
	w http.ResponseWriter
	// params holds the values captured from wildcard route segments,
	// keyed by parameter name.
	params map[string]string
	// handlers is the chain executed for the request.
	handlers []HandlerFunc
	// index is the position in handlers of the handler most recently
	// started by Next; NewContext initialises it to -1.
	index int
}
// NewContext builds a Context around the given response writer and request,
// with the handler index positioned before the first handler.
func NewContext(w http.ResponseWriter, req *http.Request) *Context {
	c := new(Context)
	c.r = req
	c.w = w
	c.index = -1
	return c
}
// SetHandlers replaces the handler chain that Next walks.
func (c *Context) SetHandlers(handlers []HandlerFunc) {
	chain := handlers
	c.handlers = chain
}
// Next runs the next handler in the chain. A middleware calls it to hand
// control to the handlers after it; code following the call runs once those
// handlers have returned.
//
// Calling Next when no handler is left (from the final handler, or on an
// empty chain) is a no-op rather than an index-out-of-range panic.
func (c *Context) Next() {
	// Advance before invoking so that a nested Next call moves on to the
	// following handler instead of re-entering the current one forever.
	c.index++
	if c.index >= len(c.handlers) {
		return
	}
	c.handlers[c.index](c)
}
// PostForm returns the value of key as reported by http.Request.FormValue.
func (c *Context) PostForm(key string) string {
	value := c.r.FormValue(key)
	return value
}
// Query returns the first value of key in the URL query string, or "" if absent.
func (c *Context) Query(key string) string {
	values := c.r.URL.Query()
	return values.Get(key)
}
// Status writes the response header with the given status code.
func (c *Context) Status(code int) {
	writer := c.w
	writer.WriteHeader(code)
}
// SetHeader sets the response header key to value, replacing existing values.
func (c *Context) SetHeader(key string, value string) {
	headers := c.w.Header()
	headers.Set(key, value)
}
// String writes a text/plain response with the given status code. The body
// is format expanded with values, as with fmt.Sprintf.
func (c *Context) String(code int, format string, values ...interface{}) {
	c.w.Header().Set("Content-Type", "text/plain")
	c.w.WriteHeader(code)
	fmt.Fprintf(c.w, format, values...)
}
// JSON writes obj as an application/json response with the given status code.
//
// obj is encoded before anything is sent: once the status line has been
// written it cannot be changed, so an encoding failure must be detected
// first in order to answer with 500 Internal Server Error.
func (c *Context) JSON(code int, obj interface{}) {
	body, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.w, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	// json.Encoder.Encode ends every value with a newline; keep that framing.
	c.w.Write(append(body, '\n'))
}
// Data writes the status code followed by the raw bytes in data.
func (c *Context) Data(code int, data []byte) {
	out := c.w
	out.WriteHeader(code)
	out.Write(data)
}
// HTML writes html as a text/html response with the given status code.
func (c *Context) HTML(code int, html string) {
	c.w.Header().Set("Content-Type", "text/html")
	c.w.WriteHeader(code)
	body := []byte(html)
	c.w.Write(body)
}
/*Engine*/

// Method is an HTTP request method name; it selects the routing trie and
// forms the first half of a route key.
type Method string

const (
	// MethodGET is the GET method.
	MethodGET Method = http.MethodGet
	// MethodPOST is the POST method.
	MethodPOST Method = http.MethodPost
)

// Engine is the framework entry point. It embeds router for route storage
// and RouterGroup so routes can be registered on the Engine itself.
type Engine struct {
	router
	RouterGroup
}

// HandlerFunc is the signature shared by route handlers and middlewares;
// it receives the per-request Context.
type HandlerFunc func(c *Context)
// New returns an Engine with empty route tables whose embedded RouterGroup
// has an empty prefix and points back at the Engine itself.
func New() *Engine {
	e := new(Engine)
	e.router = router{
		roots:     make(map[Method]*node),
		handleMap: make(map[string]HandlerFunc),
	}
	e.RouterGroup = RouterGroup{prefix: "", engine: e}
	return e
}
// GetMiddlewareByRouterKey returns the middlewares registered for the group
// prefix under which the route identified by key was added. The result is
// nil when nothing is registered for that prefix.
func GetMiddlewareByRouterKey(key string) []HandlerFunc {
	return prefixMiddlewares[routerKeyPrefixMap[key]]
}
// ServeHTTP resolves the request path against the trie of the request's
// method and, when a registered pattern matches, runs the middlewares of
// the route's group followed by the route handler.
//
// Every request that does not resolve to a handler receives a 404 response
// whose body names the requested URL. That covers both a method with no
// routes at all and a method whose trie has no matching route.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	method := Method(r.Method)
	if root, ok := e.roots[method]; ok {
		pattern, params := root.search(r.URL.Path)
		routerKey := genRouterKey(method, pattern)
		if handler, ok := e.handleMap[routerKey]; ok {
			ctx := NewContext(w, r)
			ctx.params = params
			// Build a per-request chain. Appending directly onto the
			// registered middleware slice could write the route handler
			// into that slice's spare capacity, which every request for
			// the same group shares.
			middlewares := GetMiddlewareByRouterKey(routerKey)
			handlers := make([]HandlerFunc, 0, len(middlewares)+1)
			handlers = append(handlers, middlewares...)
			handlers = append(handlers, handler)
			ctx.SetHandlers(handlers)
			ctx.Next()
			return
		}
	}
	// Write the status first; otherwise the body write implicitly sends 200 OK.
	w.WriteHeader(http.StatusNotFound)
	fmt.Fprintf(w, "404 NOT FOUND: %s\n", r.URL)
}
// Run serves HTTP on addr with the Engine as the handler and returns the
// error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	var handler http.Handler = e
	return http.ListenAndServe(addr, handler)
}
// routerKeyPrefixMap maps a route key to the group prefix the route was
// registered under.
var routerKeyPrefixMap = make(map[string]string)

/* router */

// router holds the per-method routing tries and the handler table.
type router struct {
	// roots maps an HTTP method to the root node of that method's trie.
	roots map[Method]*node
	// handleMap maps a route key (see genRouterKey) to its handler.
	handleMap map[string]HandlerFunc
}
// addRouter registers handle for method on the pattern prefix+path. It
// inserts the pattern into the method's trie (creating the trie on first
// use), binds the handler to the route key, and records which prefix the
// route key belongs to.
func (r router) addRouter(method Method, prefix string, path string, handle HandlerFunc) {
	pattern := prefix + path
	routerKey := genRouterKey(method, pattern)

	root, ok := r.roots[method]
	if !ok {
		root = newTrie()
		r.roots[method] = root
	}
	root.insert(pattern)

	// Bind route key -> handler.
	r.handleMap[routerKey] = handle
	// Bind route key -> prefix.
	routerKeyPrefixMap[routerKey] = prefix
}
// genRouterKey builds the handler-table key "<method>-<path>".
func genRouterKey(method Method, path string) string {
	const sep = "-"
	key := string(method) + sep + path
	return key
}
// Get registers handle for GET requests to path under the group's prefix.
func (r *RouterGroup) Get(path string, handle HandlerFunc) {
	engine, prefix := r.engine, r.prefix
	engine.addRouter(MethodGET, prefix, path, handle)
}
// Post registers handle for POST requests to path under the group's prefix.
func (r *RouterGroup) Post(path string, handle HandlerFunc) {
	engine, prefix := r.engine, r.prefix
	engine.addRouter(MethodPOST, prefix, path, handle)
}
// GetPrefix returns the group's path prefix.
func (r *RouterGroup) GetPrefix() string {
	prefix := r.prefix
	return prefix
}
/* trie */
const (
	// slash is the separator used to split paths and to re-join captured parts.
	slash = "/"
	// asteriskByte marks a catch-all segment such as "*filename".
	asteriskByte byte = '*'
	// colonByte marks a single-segment parameter such as ":name".
	colonByte byte = ':'
	// codeMe is the starting value of the node code counter.
	// NOTE(review): the value looks like a date (2021-06-07) and is presumably arbitrary — confirm.
	codeMe = 20210607
	// codeNotFound is returned by match when no route can be followed.
	codeNotFound = -404
)

var (
	// code is the most recently issued node code; codeIncr advances it.
	code int32 = codeMe
	// pattenMap maps the code of a pattern's final trie node to the full
	// pattern string as it was registered.
	pattenMap = make(map[int32]string)
)
// codeIncr advances the package-level code counter by one and returns the
// new value.
func codeIncr() int32 {
	code = code + 1
	return code
}
// getCode returns the current value of the package-level code counter
// without changing it.
func getCode() int32 {
	current := code
	return current
}
// node is one path segment in the routing trie.
type node struct {
	// code identifies the node; insert assigns it from codeIncr, while a
	// root created by newTrie keeps the zero value.
	code int32
	// wild is ':' or '*' for a wildcard segment and zero for a literal one.
	wild byte
	// children maps a segment's text as registered (including any ':' or
	// '*' marker) to its node.
	children map[string]*node
	// param is the wildcard's parameter name (the segment without its marker).
	param string
}
// newTrie returns an empty root node with an initialised children map.
func newTrie() *node {
	root := new(node)
	root.children = make(map[string]*node)
	return root
}
// search resolves path against the trie rooted at n. It returns the
// registered pattern recorded for the matched node ("" when none is
// recorded) and the wildcard values captured along the way.
func (n *node) search(path string) (string, map[string]string) {
	segments := parsePattern(path)
	captured := map[string]string{}
	// Start at the root (n) with the first segment.
	matched := n.match(n, segments, 0, captured)
	log.Printf("pattenMap: %+v", pattenMap)
	log.Printf("code: %d", matched)
	pattern := pattenMap[matched]
	return pattern, captured
}
// match recursively walks the trie to resolve parts[idx:].
//
// curNode is the node reached after consuming parts[:idx]; the receiver is
// only used to make the recursive calls. Values captured by wildcard nodes
// are written into params. It returns the code of the node where matching
// stopped, or codeNotFound when neither a literal nor a wildcard child can
// consume parts[idx]. The returned code is not checked against pattenMap
// here, so it may belong to a node with no registered pattern.
func (n *node) match(curNode *node, parts []string, idx int, params map[string]string) int32 {
	if idx == len(parts) {
		// Every part has been consumed.
		if curNode == nil {
			log.Printf("curNode is nil")
			// No rule matched.
			return codeNotFound
		}
		// If the final node is a wildcard, capture the last part as its value.
		if curNode.isWildChild(curNode.wild) {
			params[curNode.param] = parts[idx-1]
		}
		return curNode.code
	}
	log.Printf("parts[i]: %s", parts[idx])
	if curNode.wild == asteriskByte {
		// NOTE(review): this logs the value currently stored in params,
		// not the parameter name — confirm which one was intended.
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// curNode was entered by consuming parts[idx-1], so a '*' node
		// captures that part and everything after it.
		params[curNode.param] = strings.Join(parts[idx-1:], slash)
		// Matched a '*' catch-all rule; stop descending.
		return curNode.code
	}
	if curNode.wild == colonByte {
		log.Printf("find params: %s, value: %s", params[curNode.param], parts[idx])
		// A ':' node captures exactly the part that led to it.
		params[curNode.param] = parts[idx-1]
	}
	// Prefer an exact literal match on the next part.
	if _, has := curNode.children[parts[idx]]; !has {
		// No literal child: fall back to a wildcard child, if there is one.
		var tNode = curNode.getNode()
		if tNode == nil {
			return codeNotFound
		}
		return curNode.match(tNode, parts, idx+1, params)
	}
	return n.match(curNode.children[parts[idx]], parts, idx+1, params)
}
// getNode returns the wildcard child of n with the smallest code, or nil
// when n has no wildcard child.
func (n *node) getNode() *node {
	var best *node
	// Start above every issued code so that any wildcard child qualifies.
	limit := getCode() + 1
	for _, child := range n.children {
		if n.isWildChild(child.wild) && child.code < limit {
			best = child
			limit = child.code
		}
	}
	return best
}
// insert adds pattern to the trie rooted at n, creating one node per path
// segment as needed, and records pattern in pattenMap under the code of the
// final node.
func (n *node) insert(pattern string) {
	parts := parsePattern(pattern)
	log.Printf("pattern: %s, parts: %v", pattern, parts)
	cur := n
	for _, part := range parts {
		log.Printf("parts[i]: %s", part)
		child, ok := cur.children[part]
		if !ok {
			child = &node{children: map[string]*node{}, code: codeIncr()}
			cur.children[part] = child
		}
		cur = child
		// A leading ':' or '*' makes the segment a wildcard; the rest of
		// the segment is the parameter name.
		if marker := part[0]; n.isWildChild(marker) {
			cur.wild = marker
			cur.param = part[1:]
			log.Printf("pattern: %s set param: %s", pattern, cur.param)
		}
	}
	// Bind the final node's code to the pattern.
	pattenMap[cur.code] = pattern
}
// isWildChild reports whether b is one of the wildcard markers ':' or '*'.
func (n *node) isWildChild(b byte) bool {
	switch b {
	case colonByte, asteriskByte:
		return true
	default:
		return false
	}
}
// parsePattern splits pattern on "/" into trimmed, non-empty segments.
// Splitting stops after the first segment that starts with '*'.
func parsePattern(pattern string) []string {
	var parts []string
	for _, raw := range strings.Split(pattern, slash) {
		segment := strings.TrimSpace(raw)
		if len(segment) == 0 {
			continue
		}
		parts = append(parts, segment)
		if segment[0] == asteriskByte {
			// e.g. /static/*filepath matches both /static/fav.ico and
			// /static/js/jQuery.js. This form is common for static file
			// servers because it matches sub-paths recursively.
			break
		}
	}
	return parts
}
最后更新于
这有帮助吗?