Router path substitution #1

Merged
nikto_b merged 4 commits from path-substitute-feature into master 2025-07-19 04:13:44 +03:00
4 changed files with 104 additions and 17 deletions
Showing only changes of commit 75824d1893 - Show all commits

View File

@@ -88,7 +88,8 @@ class HTTPApp(ASGIApp):
sresp: SerializedResponse sresp: SerializedResponse
resp: BasicResponse resp: BasicResponse
try: try:
handler = self.router.match(method, path) matches, handler = self.router.match(method, path)
req.path_matches = matches
await handler(send, req) await handler(send, req)
return return

View File

@@ -17,6 +17,7 @@ class BasicRequest:
headers: CaseInsensitiveDict[str, str] headers: CaseInsensitiveDict[str, str]
query: dict[str, list[Any] | Any] query: dict[str, list[Any] | Any]
body: bytes body: bytes
path_matches: dict[str, str]
def __init__(self, def __init__(self,
method: MethodType, method: MethodType,
@@ -29,6 +30,7 @@ class BasicRequest:
self.headers = CaseInsensitiveDict(headers) self.headers = CaseInsensitiveDict(headers)
self.query = query self.query = query
self.body = body self.body = body
self.path_matches = {}
@classmethod @classmethod
def from_scope(cls, scope: Scope) -> BasicRequest: def from_scope(cls, scope: Scope) -> BasicRequest:

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import re import re
from re import Pattern
import typing import typing
from typing import Optional, Sequence from typing import Optional, Sequence
from pathspec import Pattern
from .exceptions import MethodNotAllowedException, NotFoundException from .exceptions import MethodNotAllowedException, NotFoundException
from .types import InternalHandlerType from .types import InternalHandlerType
from .internal_types import MethodType from .internal_types import MethodType
@@ -13,7 +12,7 @@ from .internal_types import MethodType
class Route: class Route:
static_subroutes: dict[str, Route] static_subroutes: dict[str, Route]
regexp_subroutes: list[tuple[Pattern, Route]] regexp_subroutes: list[tuple[Pattern, list[str], Route]]
handler: dict[MethodType, InternalHandlerType] handler: dict[MethodType, InternalHandlerType]
def __init__(self) -> None: def __init__(self) -> None:
@@ -21,15 +20,15 @@ class Route:
self.regexp_subroutes = [] self.regexp_subroutes = []
self.handler = {} self.handler = {}
def _find_regexp_subroute(self, p: Pattern) -> Optional[Route]: def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]:
for _p, _r in self.regexp_subroutes: for _p, _n, _r in self.regexp_subroutes:
if _p == p: if _p == p:
return _r return _n, _r
return None return None
def _add_regexp_subroute(self, p: Pattern, r: Route): def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route):
if self._find_regexp_subroute(p) is None: if self._find_regexp_subroute(p) is None:
self.regexp_subroutes.append((p, r)) self.regexp_subroutes.append((p, names, r))
self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True) self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True)
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None: def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
@@ -43,13 +42,16 @@ class Route:
if '}' not in part: if '}' not in part:
raise ValueError(f'Invalid subpath substitute placeholder: {part}') raise ValueError(f'Invalid subpath substitute placeholder: {part}')
re_part = part.replace('(', '\\(').replace(')', '\\)') re_part = part.replace('(', '\\(').replace(')', '\\)')
re_part = re.sub(r'\{.+}', r'(.+)', re_part) names = re.findall(r'\{(.*)}', part)
re_part = re.sub(r'\{.*}', r'(.*)', re_part)
re_part = re.compile('^' + re_part + '$') re_part = re.compile('^' + re_part + '$')
subroute = self._find_regexp_subroute(re_part) d = self._find_regexp_subroute(re_part)
if subroute is None: if d is None:
subroute = Route() subroute = Route()
self._add_regexp_subroute(re_part, subroute) self._add_regexp_subroute(re_part, names, subroute)
else:
_, subroute = d
else: else:
subroute = self.static_subroutes.get(part) subroute = self.static_subroutes.get(part)
if subroute is None: if subroute is None:
@@ -58,26 +60,35 @@ class Route:
subroute.add(method, sequence[1:], handler) subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]: def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]:
if len(sequence) == 0: if len(sequence) == 0:
if len(self.handler) == 0: if len(self.handler) == 0:
raise NotFoundException('') raise NotFoundException('')
ret = self.handler.get(method) ret = self.handler.get(method)
if ret is None: if ret is None:
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys()))) raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
return ret return {}, ret
subroute = self.static_subroutes.get(sequence[0]) subroute = self.static_subroutes.get(sequence[0])
matches = {}
if subroute is None: if subroute is None:
for p, sr in self.regexp_subroutes: for p, subst_names, sr in self.regexp_subroutes:
m = p.match(sequence[0]) m = p.findall(sequence[0])
if m is not None: if len(m) > 0:
if len(m) != len(subst_names):
raise RuntimeError('Unable to match substitutes')
subroute = sr subroute = sr
for k, v in zip(subst_names, m):
matches[k] = v
break
if subroute is None: if subroute is None:
raise NotFoundException('/'.join(sequence)) raise NotFoundException('/'.join(sequence))
return subroute.get(method, sequence[1:]) submatches, handler = subroute.get(method, sequence[1:])
submatches |= matches
return submatches, handler
class Router: class Router:
@@ -87,14 +98,22 @@ class Router:
self._root = Route() self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None: def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
assert method.upper() == method method = typing.cast(MethodType, method.upper())
substitutes = re.findall(r'\{(.*)}', path_pattern)
if len(substitutes) > 0:
subst = []
for s in substitutes:
subst += list(s)
if len(subst) != len(set(subst)):
raise ValueError('Duplicate path substitute names are prohibited')
segments = path_pattern.split('/') segments = path_pattern.split('/')
while len(segments) > 0 and len(segments[0]) == 0: while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:] segments = segments[1:]
self._root.add(method, segments, handler) self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> InternalHandlerType: def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]:
method = typing.cast(MethodType, method.upper()) method = typing.cast(MethodType, method.upper())
segments = path.split('/') segments = path.split('/')

View File

@@ -46,7 +46,7 @@ def test_router_match():
r = Router() r = Router()
r.add('GET', 'asdf', f) r.add('GET', 'asdf', f)
assert r.match('GET', '/asdf') assert r.match('GET', '/asdf') == ({}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'): with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd') r.match('GET', 'asd')
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'): with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
@@ -64,9 +64,9 @@ def test_router_pattern_match():
r.add('GET', '/{some}/asdf', f) r.add('GET', '/{some}/asdf', f)
r.add('GET', '/{some}/b{some1}c', f) r.add('GET', '/{some}/b{some1}c', f)
assert r.match('GET', '/1234/asdf') assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f)
assert r.match('GET', '/ /asdf') assert r.match('GET', '/ /asdf') == ({'some': ' '}, f)
assert r.match('GET', '/ /basdfc') assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'): with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd') r.match('GET', 'asd')