Allow router to raise MethodNotAllowed
This commit is contained in:
@@ -8,56 +8,57 @@ from .types import HandlerType, MethodType
|
|||||||
|
|
||||||
class Route:
|
class Route:
|
||||||
static_subroutes: dict[str, Route]
|
static_subroutes: dict[str, Route]
|
||||||
handler: Optional[HandlerType]
|
handler: dict[MethodType, HandlerType]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.static_subroutes = {}
|
self.static_subroutes = {}
|
||||||
self.handler = None
|
self.handler = {}
|
||||||
|
|
||||||
def add(self, sequence: Sequence[str], handler: HandlerType):
|
def add(self, method: MethodType, sequence: Sequence[str], handler: HandlerType):
|
||||||
if len(sequence) == 0:
|
if len(sequence) == 0:
|
||||||
self.handler = handler
|
self.handler[method] = handler
|
||||||
return
|
return
|
||||||
|
|
||||||
subroute = self.static_subroutes.get(sequence[0])
|
subroute = self.static_subroutes.get(sequence[0])
|
||||||
if subroute is None:
|
if subroute is None:
|
||||||
subroute = Route()
|
subroute = Route()
|
||||||
self.static_subroutes[sequence[0]] = subroute
|
self.static_subroutes[sequence[0]] = subroute
|
||||||
subroute.add(sequence[1:], handler)
|
|
||||||
|
|
||||||
def get(self, sequence: Sequence[str]) -> Optional[HandlerType]:
|
subroute.add(method, sequence[1:], handler)
|
||||||
|
|
||||||
|
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[HandlerType]:
|
||||||
if len(sequence) == 0:
|
if len(sequence) == 0:
|
||||||
return self.handler
|
ret = self.handler.get(method)
|
||||||
|
if ret is None:
|
||||||
|
# TODO: extract exceptions
|
||||||
|
raise ValueError('405')
|
||||||
|
return ret
|
||||||
subroute = self.static_subroutes.get(sequence[0])
|
subroute = self.static_subroutes.get(sequence[0])
|
||||||
if subroute is None:
|
if subroute is None:
|
||||||
raise ValueError('404')
|
raise ValueError('404')
|
||||||
return subroute.get(sequence[1:])
|
return subroute.get(method, sequence[1:])
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
_routes: dict[MethodType, Route]
|
_root: Route
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._routes = {}
|
self._root = Route()
|
||||||
|
|
||||||
def add(self, method: MethodType, path_pattern: str, handler: HandlerType):
|
def add(self, method: MethodType, path_pattern: str, handler: HandlerType):
|
||||||
assert method.upper() == method
|
assert method.upper() == method
|
||||||
|
|
||||||
root = self._routes.get(method)
|
|
||||||
if root is None:
|
|
||||||
root = Route()
|
|
||||||
self._routes[method] = root
|
|
||||||
|
|
||||||
segments = path_pattern.split('/')
|
segments = path_pattern.split('/')
|
||||||
while len(segments) > 0 and len(segments[0]) == 0:
|
while len(segments) > 0 and len(segments[0]) == 0:
|
||||||
segments = segments[1:]
|
segments = segments[1:]
|
||||||
root.add(segments, handler)
|
self._root.add(method, segments, handler)
|
||||||
|
|
||||||
def match(self, method: MethodType, path: str) -> HandlerType:
|
def match(self, method: MethodType, path: str) -> HandlerType:
|
||||||
method = typing.cast(MethodType, method.upper())
|
method = typing.cast(MethodType, method.upper())
|
||||||
root = self._routes[method]
|
|
||||||
segments = path.split('/')
|
segments = path.split('/')
|
||||||
|
|
||||||
while len(segments) > 0 and len(segments[0]) == 0:
|
while len(segments) > 0 and len(segments[0]) == 0:
|
||||||
segments = segments[1:]
|
segments = segments[1:]
|
||||||
h = root.get(segments)
|
|
||||||
|
h = self._root.get(method, segments)
|
||||||
return h
|
return h
|
||||||
|
|||||||
@@ -10,20 +10,30 @@ def test_router_root_handler():
|
|||||||
async def d(_: dict, *args):
|
async def d(_: dict, *args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def a(_: dict, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
r = Router()
|
r = Router()
|
||||||
r.add('GET', '/', f)
|
r.add('GET', '/', f)
|
||||||
r.add('GET', '', f)
|
r.add('GET', '', f)
|
||||||
assert len(r._routes.keys()) == 1
|
assert len(r._root.static_subroutes) == 0
|
||||||
assert len(r._routes['GET'].static_subroutes.keys()) == 0
|
assert len(r._root.handler) == 1
|
||||||
assert r._routes['GET'].handler is not None
|
assert r._root.handler['GET'] == f
|
||||||
assert r._routes['GET'].handler == f
|
|
||||||
|
r.add('POST', '/', a)
|
||||||
|
assert len(r._root.static_subroutes) == 0
|
||||||
|
assert len(r._root.handler) == 2
|
||||||
|
assert r._root.handler['GET'] == f
|
||||||
|
assert r._root.handler['POST'] == a
|
||||||
|
|
||||||
r.add('GET', '/asdf', d)
|
r.add('GET', '/asdf', d)
|
||||||
assert len(r._routes.keys()) == 1
|
assert len(r._root.static_subroutes) == 1
|
||||||
assert len(r._routes['GET'].static_subroutes.keys()) == 1
|
assert len(r._root.handler) == 2
|
||||||
assert r._routes['GET'].handler is not None
|
assert r._root.handler['GET'] == f
|
||||||
assert r._routes['GET'].handler == f
|
assert r._root.handler['POST'] == a
|
||||||
assert r._routes['GET'].static_subroutes['asdf'].handler == d
|
assert len(r._root.static_subroutes['asdf'].static_subroutes) == 0
|
||||||
|
assert len(r._root.static_subroutes['asdf'].handler) == 1
|
||||||
|
assert r._root.static_subroutes['asdf'].handler['GET'] == d
|
||||||
|
|
||||||
|
|
||||||
def test_router_match():
|
def test_router_match():
|
||||||
@@ -36,5 +46,7 @@ def test_router_match():
|
|||||||
r = Router()
|
r = Router()
|
||||||
r.add('GET', 'asdf', f)
|
r.add('GET', 'asdf', f)
|
||||||
assert r.match('GET', '/asdf')
|
assert r.match('GET', '/asdf')
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match='404'):
|
||||||
r.match('GET', 'asd')
|
r.match('GET', 'asd')
|
||||||
|
with pytest.raises(ValueError, match='405'):
|
||||||
|
r.match('POST', 'asdf')
|
||||||
|
|||||||
Reference in New Issue
Block a user