From f2ec4fad14ffe73b1a905a1ad6ce1afaa061cdf1 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 21:08:46 +0300 Subject: [PATCH] Allow default option to be overriden if was ocasionally provided on a conversion path --- src/breakshaft/models.py | 4 ++-- src/breakshaft/renderer.py | 26 +++++++++++++++++------ src/breakshaft/templates/convertor.jinja2 | 4 ++-- tests/test_default_args.py | 8 +++++++ 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index fdd1012..0d1afb4 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -27,8 +27,8 @@ class ConversionPoint: return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}' @property - def fn_args(self) -> list[type]: - return extract_func_argtypes_seq(self.fn) + def fn_args(self) -> list[tuple[str, type]]: + return extract_func_args(self.fn) @property def is_ctx_manager(self) -> bool: diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index be2b5f7..8494e6b 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -23,15 +23,24 @@ class ConversionRenderData: inj_hash: str funchash: str funcname: str - funcargs: list[str] + funcargs: list[tuple[str, str]] is_ctxmanager: bool is_async: bool @classmethod - def from_inj(cls, inj: ConversionPoint): + def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): + argmap = inj.fn_args + fnargs = [] - for argtype in inj.requires: - fnargs.append(hashname(argtype)) + for arg_id, argtype in enumerate(inj.requires): + argname = argmap[arg_id][0] + fnargs.append((argname, hashname(argtype))) + + for arg_id, (argtype, _) in enumerate(inj.opt_args, len(inj.requires)): + argname = argmap[arg_id][0] + if argtype in provided_types: + fnargs.append((argname, hashname(argtype))) + return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async) @@ -66,10 +75,15 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): conversion_models = [] ret_hash = 0 is_async = force_async + for call_id, call in enumerate(callseq): + + provided_types = set(from_types) + for _call in callseq[:call_id]: + provided_types |= {_call.injects} + provided_types |= set(_call.requires) - for call in callseq: fnmap[hash(call.fn)] = call.fn - conv = ConversionRenderData.from_inj(call) + conv = ConversionRenderData.from_inj(call, provided_types) if conv not in conversion_models: conversion_models.append(conv) if call.is_async: diff --git a/src/breakshaft/templates/convertor.jinja2 b/src/breakshaft/templates/convertor.jinja2 index 74128ad..1bc3189 100644 --- a/src/breakshaft/templates/convertor.jinja2 +++ b/src/breakshaft/templates/convertor.jinja2 @@ -3,11 +3,11 @@ {% for conv in conversions %} {% if conv.is_ctxmanager %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}: + {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}: {% set ns.indent = ns.indent + 1 %} {% else %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) + {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) {% endif %} {% endfor %} {{ ' ' * ns.indent }}return _{{ret_hash}} diff --git a/tests/test_default_args.py b/tests/test_default_args.py index 58d5243..3c80b14 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -42,6 +42,10 @@ def test_default_consumer_args(): dep = fn2(123) assert dep == (123, '42') + fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '1') + assert dep == (123, '1') + def test_default_inj_args(): repo = ConvRepo() @@ -68,3 +72,7 @@ def test_default_inj_args(): fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn2(123) assert dep == 123 + 42 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123