From a33914eceeaf0f155cece9357baf5d61e0e96f65 Mon Sep 17 00:00:00 2001 From: Tomer Haskalovitch Date: Sun, 2 Feb 2025 12:25:32 +0200 Subject: [PATCH] mgr/dashboard: introduce "convert_to_model" util fixes: https://tracker.ceph.com/issues/62705 Signed-off-by: Tomer Haskalovitch --- .../mgr/dashboard/controllers/nvmeof.py | 64 ++- src/pybind/mgr/dashboard/model/nvmeof.py | 40 +- .../mgr/dashboard/services/nvmeof_client.py | 124 ++++- .../mgr/dashboard/tests/test_nvmeof_client.py | 459 ++++++++++++++++++ 4 files changed, 655 insertions(+), 32 deletions(-) create mode 100644 src/pybind/mgr/dashboard/tests/test_nvmeof_client.py diff --git a/src/pybind/mgr/dashboard/controllers/nvmeof.py b/src/pybind/mgr/dashboard/controllers/nvmeof.py index 3ca1775349aaf..e32de917c04af 100644 --- a/src/pybind/mgr/dashboard/controllers/nvmeof.py +++ b/src/pybind/mgr/dashboard/controllers/nvmeof.py @@ -23,8 +23,8 @@ NVME_SCHEMA = { } try: - from ..services.nvmeof_client import NVMeoFClient, empty_response, \ - handle_nvmeof_error, map_collection, map_model + from ..services.nvmeof_client import NVMeoFClient, convert_to_model, \ + empty_response, handle_nvmeof_error, map_collection, pick except ImportError as e: logger.error("Failed to import NVMeoFClient and related components: %s", e) else: @@ -33,7 +33,7 @@ else: class NVMeoFGateway(RESTController): @EndpointDoc("Get information about the NVMeoF gateway") @NvmeofCLICommand("nvmeof gw info") - @map_model(model.GatewayInfo) + @convert_to_model(model.GatewayInfo) @handle_nvmeof_error def list(self, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.get_gateway_info( @@ -56,7 +56,7 @@ else: @ReadPermission @Endpoint('GET', '/version') @NvmeofCLICommand("nvmeof gw version") - @map_model(model.GatewayVersion) + @convert_to_model(model.GatewayVersion) @handle_nvmeof_error def version(self, gw_group: Optional[str] = None): gw_info = NVMeoFClient(gw_group=gw_group).stub.get_gateway_info( @@ -69,7 +69,7 @@ else: @ReadPermission @Endpoint('GET', '/log_level') @NvmeofCLICommand("nvmeof gw get_log_level") - @map_model(model.GatewayLogLevelInfo) + @convert_to_model(model.GatewayLogLevelInfo) @handle_nvmeof_error def get_log_level(self, gw_group: Optional[str] = None): gw_log_level = NVMeoFClient(gw_group=gw_group).stub.get_gateway_log_level( @@ -80,7 +80,7 @@ else: @ReadPermission @Endpoint('PUT', '/log_level') @NvmeofCLICommand("nvmeof gw set_log_level") - @map_model(model.RequestStatus) + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def set_log_level(self, log_level: str, gw_group: Optional[str] = None): log_level = log_level.lower() @@ -95,7 +95,7 @@ else: @ReadPermission @Endpoint('GET', '/log_level') @NvmeofCLICommand("nvmeof spdk_log_level get") - @map_model(model.SpdkNvmfLogFlagsAndLevelInfo) + @convert_to_model(model.SpdkNvmfLogFlagsAndLevelInfo) @handle_nvmeof_error def get_spdk_log_level(self, gw_group: Optional[str] = None): spdk_log_level = NVMeoFClient(gw_group=gw_group).stub.get_spdk_nvmf_log_flags_and_level( @@ -106,7 +106,7 @@ else: @ReadPermission @Endpoint('PUT', '/log_level') @NvmeofCLICommand("nvmeof spdk_log_level set") - @map_model(model.RequestStatus) + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def set_spdk_log_level(self, log_level: Optional[str] = None, print_level: Optional[str] = None, gw_group: Optional[str] = None): @@ -121,7 +121,7 @@ else: @ReadPermission @Endpoint('PUT', '/log_level/disable') @NvmeofCLICommand("nvmeof spdk_log_level disable") - @map_model(model.RequestStatus) + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def disable_spdk_log_level(self, gw_group: Optional[str] = None): spdk_log_level = NVMeoFClient(gw_group=gw_group).stub.disable_spdk_nvmf_logs( @@ -133,8 +133,9 @@ else: @APIDoc("NVMe-oF Subsystem Management API", "NVMe-oF Subsystem") class NVMeoFSubsystem(RESTController): @EndpointDoc("List all NVMeoF subsystems") + @pick(field="subsystems", first=True) @NvmeofCLICommand("nvmeof subsystem list") - @map_collection(model.Subsystem, pick="subsystems") + @convert_to_model(model.SubsystemList) @handle_nvmeof_error def list(self, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_subsystems( @@ -148,8 +149,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) + @pick(field="subsystems", first=True) @NvmeofCLICommand("nvmeof subsystem get") - @map_model(model.Subsystem, first="subsystems") + @convert_to_model(model.SubsystemList) @handle_nvmeof_error def get(self, nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_subsystems( @@ -165,8 +167,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof subsystem add") @empty_response + @NvmeofCLICommand("nvmeof subsystem add") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def create(self, nqn: str, enable_ha: bool, max_namespaces: int = 1024, gw_group: Optional[str] = None): @@ -184,8 +187,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof subsystem del") @empty_response + @NvmeofCLICommand("nvmeof subsystem del") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def delete(self, nqn: str, force: Optional[str] = "false", gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.delete_subsystem( @@ -204,8 +208,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) + @pick("listeners") @NvmeofCLICommand("nvmeof listener list") - @map_collection(model.Listener, pick="listeners") + @convert_to_model(model.ListenerList) @handle_nvmeof_error def list(self, nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_listeners( @@ -223,8 +228,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof listener add") @empty_response + @NvmeofCLICommand("nvmeof listener add") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def create( self, @@ -256,8 +262,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof listener del") @empty_response + @NvmeofCLICommand("nvmeof listener del") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def delete( self, @@ -290,8 +297,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) + @pick("namespaces") @NvmeofCLICommand("nvmeof ns list") - @map_collection(model.Namespace, pick="namespaces") + @convert_to_model(model.NamespaceList) @handle_nvmeof_error def list(self, nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_namespaces( @@ -306,8 +314,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) + @pick("namespaces", first=True) @NvmeofCLICommand("nvmeof ns get") - @map_model(model.Namespace, first="namespaces") + @convert_to_model(model.NamespaceList) @handle_nvmeof_error def get(self, nqn: str, nsid: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_namespaces( @@ -325,7 +334,7 @@ else: }, ) @NvmeofCLICommand("nvmeof ns get_io_stats") - @map_model(model.NamespaceIOStats) + @convert_to_model(model.NamespaceIOStats) @handle_nvmeof_error def io_stats(self, nqn: str, nsid: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.namespace_get_io_stats( @@ -357,7 +366,7 @@ else: }, ) @NvmeofCLICommand("nvmeof ns add") - @map_model(model.NamespaceCreation) + @convert_to_model(model.NamespaceCreation) @handle_nvmeof_error def create( self, @@ -404,8 +413,9 @@ else: "trash_image": Param(bool, "Trash RBD image after removing namespace") }, ) + @pick("namespaces", first=True) @NvmeofCLICommand("nvmeof ns update") - @map_model(model.Namespace, first="namespaces") + @convert_to_model(model.NamespaceList) @handle_nvmeof_error def update( self, @@ -486,8 +496,9 @@ else: "force": Param(str, "Force remove the RBD image") }, ) - @NvmeofCLICommand("nvmeof ns del") @empty_response + @NvmeofCLICommand("nvmeof ns del") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def delete( self, @@ -538,8 +549,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof host add") @empty_response + @NvmeofCLICommand("nvmeof host add") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def create(self, nqn: str, host_nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.add_host( @@ -554,8 +566,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) - @NvmeofCLICommand("nvmeof host del") @empty_response + @NvmeofCLICommand("nvmeof host del") + @convert_to_model(model.RequestStatus) @handle_nvmeof_error def delete(self, nqn: str, host_nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.remove_host( @@ -572,8 +585,9 @@ else: "gw_group": Param(str, "NVMeoF gateway group", True, None), }, ) + @pick("connections") @NvmeofCLICommand("nvmeof connection list") - @map_collection(model.Connection, pick="connections") + @convert_to_model(model.ConnectionList) @handle_nvmeof_error def list(self, nqn: str, gw_group: Optional[str] = None): return NVMeoFClient(gw_group=gw_group).stub.list_connections( diff --git a/src/pybind/mgr/dashboard/model/nvmeof.py b/src/pybind/mgr/dashboard/model/nvmeof.py index 7661bae305dd8..49b175a02e1ec 100644 --- a/src/pybind/mgr/dashboard/model/nvmeof.py +++ b/src/pybind/mgr/dashboard/model/nvmeof.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional class GatewayInfo(NamedTuple): @@ -19,14 +19,14 @@ class GatewayVersion(NamedTuple): class GatewayLogLevelInfo(NamedTuple): status: int error_message: str - log_level: int + log_level: str class SpdkNvmfLogFlagsAndLevelInfo(NamedTuple): status: int error_message: str - log_level: int - log_print_level: int + log_level: str + log_print_level: str class Subsystem(NamedTuple): @@ -41,6 +41,12 @@ class Subsystem(NamedTuple): max_namespaces: int +class SubsystemList(NamedTuple): + status: int + error_message: str + subsystems: List[Subsystem] + + class Connection(NamedTuple): traddr: str trsvcid: int @@ -51,7 +57,16 @@ class Connection(NamedTuple): controller_id: int +class ConnectionList(NamedTuple): + status: int + error_message: str + subsystem_nqn: str + connections: List[Connection] + + class NamespaceCreation(NamedTuple): + status: int + error_message: str nsid: int @@ -71,7 +86,16 @@ class Namespace(NamedTuple): trash_image: bool +class NamespaceList(NamedTuple): + status: int + error_message: str + namespaces: List[Namespace] + + class NamespaceIOStats(NamedTuple): + status: int + error_message: str + subsystem_nqn: str nsid: int uuid: str bdev_name: str @@ -95,7 +119,7 @@ class NamespaceIOStats(NamedTuple): copy_latency_ticks: int max_copy_latency_ticks: int min_copy_latency_ticks: int - # io_error: List[int] + io_error: List[int] class Listener(NamedTuple): @@ -106,6 +130,12 @@ class Listener(NamedTuple): trsvcid: int = 4420 +class ListenerList(NamedTuple): + status: int + error_message: str + listeners: List[Listener] + + class Host(NamedTuple): nqn: str diff --git a/src/pybind/mgr/dashboard/services/nvmeof_client.py b/src/pybind/mgr/dashboard/services/nvmeof_client.py index 0490b2728f37a..556b59eb36554 100644 --- a/src/pybind/mgr/dashboard/services/nvmeof_client.py +++ b/src/pybind/mgr/dashboard/services/nvmeof_client.py @@ -1,7 +1,7 @@ import functools import logging from collections.abc import Iterable -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type +from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Type from ..exceptions import DashboardException from .nvmeof_conf import NvmeofGatewaysConfig @@ -11,6 +11,7 @@ logger = logging.getLogger("nvmeof_client") try: import grpc # type: ignore import grpc._channel # type: ignore + from google.protobuf.json_format import MessageToDict # type: ignore from google.protobuf.message import Message # type: ignore from .proto import gateway_pb2 as pb2 # type: ignore @@ -108,7 +109,7 @@ else: def map_collection( model: Type[NamedTuple], - pick: str, + pick: str, # pylint: disable=redefined-outer-name finalize: Optional[Callable[[Message, Collection], Collection]] = None, ) -> Callable[..., Callable[..., Collection]]: def decorator(func: Callable[..., Message]) -> Callable[..., Collection]: @@ -171,3 +172,122 @@ else: func(*args, **kwargs) return wrapper + + class MaxRecursionDepthError(Exception): + pass + + def _convert(value, field_type, depth, max_depth) -> Generator: + if depth > max_depth: + raise MaxRecursionDepthError( + f"Maximum nesting depth of {max_depth} exceeded at depth {depth}.") + + if isinstance(value, dict) and hasattr(field_type, '_fields'): + # Lazily create NamedTuple for nested dicts + yield from _lazily_create_namedtuple(value, field_type, depth + 1, max_depth) + elif isinstance(value, list): + # Handle empty lists directly + if not value: + yield [] + else: + # Lazily process each item in the list based on the expected item type + item_type = field_type.__args__[0] if hasattr(field_type, '__args__') else None + processed_items = [] + for v in value: + if item_type: + processed_items.append(next(_convert(v, item_type, + depth + 1, max_depth), None)) + else: + processed_items.append(v) + yield processed_items + else: + # Yield the value as is for simple types + yield value + + def _lazily_create_namedtuple(data: Any, target_type: Type[NamedTuple], + depth: int, max_depth: int) -> Generator: + """ Lazily create NamedTuple from a dict """ + field_values = {} + for field, field_type in zip(target_type._fields, + target_type.__annotations__.values()): + # these conditions are complex since we need to navigate between dicts, + # empty dicts and objects + if isinstance(data, dict) and data.get(field) is not None: + try: + field_values[field] = next(_convert(data.get(field), field_type, + depth, max_depth), None) + except StopIteration: + return + elif hasattr(data, field): + try: + field_values[field] = next(_convert(getattr(data, field), field_type, + depth, max_depth), None) + except StopIteration: + return + else: + # If the field is missing assign None + field_values[field] = None + + namedtuple_instance = target_type(**field_values) # type: ignore + yield namedtuple_instance + + def obj_to_namedtuple(data: Any, target_type: Type[NamedTuple], + max_depth: int = 4) -> NamedTuple: + """ + Convert an object or dict to a NamedTuple, handling nesting and lists lazily. + This will raise an error if nesting depth exceeds the max depth (default 4) + to avoid bloating the memory in case of mutual references between objects. + + :param data: The input data - object or dictionary + :param target_type: The target NamedTuple type + :param max_depth: The maximum depth allowed for recursion + :return: An instance of the target NamedTuple with fields populated from the JSON + """ + + if not isinstance(target_type, type) or not hasattr(target_type, '_fields'): + raise TypeError("target_type must be a NamedTuple type.") + if isinstance(data, list): + raise TypeError("data can't be a list.") + if data is None: + raise TypeError("data can't be None.") + namedtuple_values = next(_lazily_create_namedtuple(data, target_type, 1, max_depth)) + return namedtuple_values + + def namedtuple_to_dict(obj): + if isinstance(obj, tuple) and hasattr(obj, '_asdict'): + # If it's a namedtuple, convert it to a dictionary + return {k: namedtuple_to_dict(v) for k, v in obj._asdict().items()} + if isinstance(obj, list): + # If it's a list, check each item and convert if it's a namedtuple + return [ + namedtuple_to_dict(item) + if isinstance(item, tuple) and hasattr(item, '_asdict') + else item + for item in obj + ] + return obj + + def convert_to_model(model: Type[NamedTuple]) -> Callable[..., Callable[..., Model]]: + def decorator(func: Callable[..., Message]) -> Callable[..., Model]: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Model: + message = func(*args, **kwargs) + msg_dict = MessageToDict(message, including_default_value_fields=True, + preserving_proto_field_name=True) + return namedtuple_to_dict(obj_to_namedtuple(msg_dict, model)) + + return wrapper + + return decorator + + # pylint: disable-next=redefined-outer-name + def pick(field: str, first: bool = False) -> Callable[..., Callable[..., object]]: + def decorator(func: Callable[..., Dict]) -> Callable[..., object]: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> object: + model = func(*args, **kwargs) + field_to_ret = model[field] + if first: + field_to_ret = field_to_ret[0] + return field_to_ret + return wrapper + return decorator diff --git a/src/pybind/mgr/dashboard/tests/test_nvmeof_client.py b/src/pybind/mgr/dashboard/tests/test_nvmeof_client.py new file mode 100644 index 0000000000000..76fcd63f7cf50 --- /dev/null +++ b/src/pybind/mgr/dashboard/tests/test_nvmeof_client.py @@ -0,0 +1,459 @@ +from typing import Dict, List, NamedTuple, Optional +from unittest.mock import MagicMock + +import pytest + +from ..services import nvmeof_client +from ..services.nvmeof_client import MaxRecursionDepthError, convert_to_model, \ + obj_to_namedtuple, pick + + +class TestObjToNamedTuple: + def test_basic(self): + class Person(NamedTuple): + name: str + age: int + + class P: + def __init__(self, name, age): + self._name = name + self._age = age + + @property + def name(self): + return self._name + + @property + def age(self): + return self._age + + obj = P("Alice", 25) + + person = obj_to_namedtuple(obj, Person) + assert person.name == "Alice" + assert person.age == 25 + + def test_nested(self): + class Address(NamedTuple): + street: str + city: str + + class Person(NamedTuple): + name: str + age: int + address: Address + + obj = MagicMock() + obj.name = "Bob" + obj.age = 30 + obj.address.street = "456 Oak St" + obj.address.city = "Springfield" + + person = obj_to_namedtuple(obj, Person) + assert person.name == "Bob" + assert person.age == 30 + assert person.address.street == "456 Oak St" + assert person.address.city == "Springfield" + + def test_empty_obj(self): + class Person(NamedTuple): + name: str + age: int + + obj = object() + + person = obj_to_namedtuple(obj, Person) + assert person.name is None + assert person.age is None + + def test_empty_list_or_dict(self): + class Person(NamedTuple): + name: str + hobbies: List[str] + address: Dict[str, str] + + class P: + def __init__(self, name, hobbies, address): + self._name = name + self._hobbies = hobbies + self._address = address + + @property + def name(self): + return self._name + + @property + def hobbies(self): + return self._hobbies + + @property + def address(self): + return self._address + name = "George" + obj = P(name, [], {}) + + person = obj_to_namedtuple(obj, Person) + assert person.name == "George" + assert person.hobbies == [] + assert person.address == {} + + +class TestJsonToNamedTuple: + + def test_basic(self): + class Person(NamedTuple): + name: str + age: int + + json_data = { + "name": "Alice", + "age": 25 + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Alice" + assert person.age == 25 + + def test_nested(self): + class Address(NamedTuple): + street: str + city: str + + class Person(NamedTuple): + name: str + age: int + address: Address + + json_data = { + "name": "Bob", + "age": 30, + "address": { + "street": "456 Oak St", + "city": "Springfield" + } + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Bob" + assert person.age == 30 + assert person.address.street == "456 Oak St" + assert person.address.city == "Springfield" + + def test_list(self): + class Person(NamedTuple): + name: str + hobbies: List[str] + + json_data = { + "name": "Charlie", + "hobbies": ["reading", "cycling", "swimming"] + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Charlie" + assert person.hobbies == ["reading", "cycling", "swimming"] + + def test_nested_list(self): + class Address(NamedTuple): + street: str + city: str + + class Person(NamedTuple): + name: str + addresses: List[Address] + + json_data = { + "name": "Diana", + "addresses": [ + {"street": "789 Pine St", "city": "Oakville"}, + {"street": "101 Maple Ave", "city": "Mapleton"} + ] + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Diana" + assert len(person.addresses) == 2 + assert person.addresses[0].street == "789 Pine St" + assert person.addresses[1].street == "101 Maple Ave" + + def test_missing_fields(self): + class Person(NamedTuple): + name: str + age: int + address: str + + json_data = { + "name": "Eva", + "age": 40 + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Eva" + assert person.age == 40 + assert person.address is None + + def test_redundant_fields(self): + class Person(NamedTuple): + name: str + age: int + + json_data = { + "name": "Eva", + "age": 40, + "last_name": "Cohen" + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "Eva" + assert person.age == 40 + + def test_max_depth_exceeded(self): + class Bla(NamedTuple): + a: str + + class Address(NamedTuple): + street: str + city: str + bla: Bla + + class Person(NamedTuple): + name: str + address: Address + + json_data = { + "name": "Frank", + "address": { + "street": "123 Elm St", + "city": "Somewhere", + "bla": { + "a": "blabla", + } + } + } + + with pytest.raises(MaxRecursionDepthError): + obj_to_namedtuple(json_data, Person, max_depth=2) + + def test_empty_json(self): + class Person(NamedTuple): + name: str + age: int + + json_data = {} + + person = obj_to_namedtuple(json_data, Person) + assert person.name is None + assert person.age is None + + def test_empty_list_or_dict(self): + class Person(NamedTuple): + name: str + hobbies: List[str] + address: Dict[str, str] + + json_data = { + "name": "George", + "hobbies": [], + "address": {} + } + + person = obj_to_namedtuple(json_data, Person) + assert person.name == "George" + assert person.hobbies == [] + assert person.address == {} + + def test_depth_within_limit(self): + class Address(NamedTuple): + street: str + city: str + + class Person(NamedTuple): + name: str + address: Address + + json_data = { + "name": "Helen", + "address": { + "street": "123 Main St", + "city": "Metropolis" + } + } + + person = obj_to_namedtuple(json_data, Person, max_depth=4) + assert person.name == "Helen" + assert person.address.street == "123 Main St" + assert person.address.city == "Metropolis" + + +class Boy(NamedTuple): + name: str + age: int + + +class Adult(NamedTuple): + name: str + age: int + children: List[str] + hobby: Optional[str] + + +class EmptyModel(NamedTuple): + pass + + +@pytest.fixture(name="person_func") +def fixture_person_func(): + @convert_to_model(Boy) + def get_person() -> dict: + return {"name": "Alice", "age": 30} + return get_person + + +@pytest.fixture(name="empty_func") +def fixture_empty_func(): + @convert_to_model(EmptyModel) + def get_empty_data() -> dict: + return {} + return get_empty_data + + +@pytest.fixture(name="disable_message_to_dict") +def fixture_disable_message_to_dict(monkeypatch): + monkeypatch.setattr(nvmeof_client, 'MessageToDict', lambda x, **kwargs: x) + + +class TestConvertToModel: + def test_basic_functionality(self, person_func, disable_message_to_dict): + # pylint: disable=unused-argument + result = person_func() + assert result == {'name': 'Alice', 'age': 30} + + def test_empty_output(self, disable_message_to_dict): + # pylint: disable=unused-argument + @convert_to_model(Boy) + def get_empty_person() -> dict: + return {} + + result = get_empty_person() + assert result == {'name': None, 'age': None} # Assuming default values for empty fields + + def test_non_dict_return_value(self, disable_message_to_dict): + # pylint: disable=unused-argument + @convert_to_model(Boy) + def get_person_list() -> list: + return ["Alice", 30] # This is an invalid return type + + with pytest.raises(TypeError): + get_person_list() + + def test_optional_fields(self, disable_message_to_dict): + # pylint: disable=unused-argument + @convert_to_model(Adult) + def get_adult() -> dict: + return {"name": "Charlie", "age": 40, "children": []} + + result = get_adult() + assert result == {'name': 'Charlie', 'age': 40, "children": [], 'hobby': None} + + def test_nested_fields(self, disable_message_to_dict): + # pylint: disable=unused-argument + @convert_to_model(Adult) + def get_adult() -> dict: + return {"name": "Charlie", "age": 40, "children": [{"name": "Alice", "age": 30}]} + + result = get_adult() + assert result == {'name': 'Charlie', 'age': 40, + "children": [{"name": "Alice", "age": 30}], 'hobby': None} + + def test_none_as_input(self, disable_message_to_dict): + # pylint: disable=unused-argument + @convert_to_model(Boy) + def get_none_person() -> dict: + return None + + with pytest.raises(TypeError): + get_none_person() + + def test_multiple_function_calls(self, person_func, disable_message_to_dict): + # pylint: disable=unused-argument + result1 = person_func() + result2 = person_func() + assert result1 == result2 + + def test_empty_model(self, empty_func, disable_message_to_dict): + # pylint: disable=unused-argument + result = empty_func() + assert result == {} + + +class TestPick: + def test_basic_field_access(self): + @pick("name") + def get_person(): + return {"name": "Alice", "height": 170} + + assert get_person() == "Alice" + + def test_first_true_on_string_field(self): + @pick("name", first=True) + def get_person(): + return {"name": "Alice"} + + assert get_person() == "A" + + def test_first_true_on_list_field(self): + @pick("tags", first=True) + def get_item(): + return {"item": "Shirt", "tags": ["red", "cotton", "medium"]} + + assert get_item() == "red" + + def test_default_field_access_on_list_field(self): + @pick("tags") + def get_item(): + return {"item": "Shirt", "tags": ["red", "cotton", "medium"]} + + assert get_item() == ["red", "cotton", "medium"] + + def test_nested_models(self): + @pick("address") + def get_person(): + return {"name": "Alice", "address": {"state": "New York", "country": "USA"}} + + assert get_person() == {"state": "New York", "country": "USA"} + + def test_field_not_present(self): + @pick("email") + def get_person(): + return {"name": "Alice", "address": {"state": "New York", "country": "USA"}} + + with pytest.raises(KeyError): + get_person() + + def test_first_true_on_empty_collection(self): + @pick("tags", first=True) + def get_item(): + return {"item": "Shirt", "tags": []} + with pytest.raises(IndexError): + get_item() + + def test_first_true_on_empty_string(self): + @pick("name", first=True) + def get_person(): + return {"name": ""} + with pytest.raises(IndexError): + get_person() + + def test_none_type_field(self): + @pick("job") + def get_person(): + return {"name": ""} + with pytest.raises(KeyError): + get_person() + + def test_none_model(self): + @pick("name") + def get_person(): + return None + with pytest.raises(TypeError): + get_person() -- 2.39.5