}
try:
- from ..services.nvmeof_client import NVMeoFClient, empty_response, \
- handle_nvmeof_error, map_collection, map_model
+ from ..services.nvmeof_client import NVMeoFClient, convert_to_model, \
+ empty_response, handle_nvmeof_error, map_collection, pick
except ImportError as e:
logger.error("Failed to import NVMeoFClient and related components: %s", e)
else:
class NVMeoFGateway(RESTController):
@EndpointDoc("Get information about the NVMeoF gateway")
@NvmeofCLICommand("nvmeof gw info")
- @map_model(model.GatewayInfo)
+ @convert_to_model(model.GatewayInfo)
@handle_nvmeof_error
def list(self, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.get_gateway_info(
@ReadPermission
@Endpoint('GET', '/version')
@NvmeofCLICommand("nvmeof gw version")
- @map_model(model.GatewayVersion)
+ @convert_to_model(model.GatewayVersion)
@handle_nvmeof_error
def version(self, gw_group: Optional[str] = None):
gw_info = NVMeoFClient(gw_group=gw_group).stub.get_gateway_info(
@ReadPermission
@Endpoint('GET', '/log_level')
@NvmeofCLICommand("nvmeof gw get_log_level")
- @map_model(model.GatewayLogLevelInfo)
+ @convert_to_model(model.GatewayLogLevelInfo)
@handle_nvmeof_error
def get_log_level(self, gw_group: Optional[str] = None):
gw_log_level = NVMeoFClient(gw_group=gw_group).stub.get_gateway_log_level(
@ReadPermission
@Endpoint('PUT', '/log_level')
@NvmeofCLICommand("nvmeof gw set_log_level")
- @map_model(model.RequestStatus)
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def set_log_level(self, log_level: str, gw_group: Optional[str] = None):
log_level = log_level.lower()
@ReadPermission
@Endpoint('GET', '/log_level')
@NvmeofCLICommand("nvmeof spdk_log_level get")
- @map_model(model.SpdkNvmfLogFlagsAndLevelInfo)
+ @convert_to_model(model.SpdkNvmfLogFlagsAndLevelInfo)
@handle_nvmeof_error
def get_spdk_log_level(self, gw_group: Optional[str] = None):
spdk_log_level = NVMeoFClient(gw_group=gw_group).stub.get_spdk_nvmf_log_flags_and_level(
@ReadPermission
@Endpoint('PUT', '/log_level')
@NvmeofCLICommand("nvmeof spdk_log_level set")
- @map_model(model.RequestStatus)
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def set_spdk_log_level(self, log_level: Optional[str] = None,
print_level: Optional[str] = None, gw_group: Optional[str] = None):
@ReadPermission
@Endpoint('PUT', '/log_level/disable')
@NvmeofCLICommand("nvmeof spdk_log_level disable")
- @map_model(model.RequestStatus)
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def disable_spdk_log_level(self, gw_group: Optional[str] = None):
spdk_log_level = NVMeoFClient(gw_group=gw_group).stub.disable_spdk_nvmf_logs(
@APIDoc("NVMe-oF Subsystem Management API", "NVMe-oF Subsystem")
class NVMeoFSubsystem(RESTController):
@EndpointDoc("List all NVMeoF subsystems")
+ @pick(field="subsystems", first=True)
@NvmeofCLICommand("nvmeof subsystem list")
- @map_collection(model.Subsystem, pick="subsystems")
+ @convert_to_model(model.SubsystemList)
@handle_nvmeof_error
def list(self, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_subsystems(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
+ @pick(field="subsystems", first=True)
@NvmeofCLICommand("nvmeof subsystem get")
- @map_model(model.Subsystem, first="subsystems")
+ @convert_to_model(model.SubsystemList)
@handle_nvmeof_error
def get(self, nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_subsystems(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof subsystem add")
@empty_response
+ @NvmeofCLICommand("nvmeof subsystem add")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def create(self, nqn: str, enable_ha: bool, max_namespaces: int = 1024,
gw_group: Optional[str] = None):
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof subsystem del")
@empty_response
+ @NvmeofCLICommand("nvmeof subsystem del")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def delete(self, nqn: str, force: Optional[str] = "false", gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.delete_subsystem(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
+ @pick("listeners")
@NvmeofCLICommand("nvmeof listener list")
- @map_collection(model.Listener, pick="listeners")
+ @convert_to_model(model.ListenerList)
@handle_nvmeof_error
def list(self, nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_listeners(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof listener add")
@empty_response
+ @NvmeofCLICommand("nvmeof listener add")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def create(
self,
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof listener del")
@empty_response
+ @NvmeofCLICommand("nvmeof listener del")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def delete(
self,
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
+ @pick("namespaces")
@NvmeofCLICommand("nvmeof ns list")
- @map_collection(model.Namespace, pick="namespaces")
+ @convert_to_model(model.NamespaceList)
@handle_nvmeof_error
def list(self, nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_namespaces(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
+ @pick("namespaces", first=True)
@NvmeofCLICommand("nvmeof ns get")
- @map_model(model.Namespace, first="namespaces")
+ @convert_to_model(model.NamespaceList)
@handle_nvmeof_error
def get(self, nqn: str, nsid: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_namespaces(
},
)
@NvmeofCLICommand("nvmeof ns get_io_stats")
- @map_model(model.NamespaceIOStats)
+ @convert_to_model(model.NamespaceIOStats)
@handle_nvmeof_error
def io_stats(self, nqn: str, nsid: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.namespace_get_io_stats(
},
)
@NvmeofCLICommand("nvmeof ns add")
- @map_model(model.NamespaceCreation)
+ @convert_to_model(model.NamespaceCreation)
@handle_nvmeof_error
def create(
self,
"trash_image": Param(bool, "Trash RBD image after removing namespace")
},
)
+ @pick("namespaces", first=True)
@NvmeofCLICommand("nvmeof ns update")
- @map_model(model.Namespace, first="namespaces")
+ @convert_to_model(model.NamespaceList)
@handle_nvmeof_error
def update(
self,
"force": Param(str, "Force remove the RBD image")
},
)
- @NvmeofCLICommand("nvmeof ns del")
@empty_response
+ @NvmeofCLICommand("nvmeof ns del")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def delete(
self,
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof host add")
@empty_response
+ @NvmeofCLICommand("nvmeof host add")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def create(self, nqn: str, host_nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.add_host(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
- @NvmeofCLICommand("nvmeof host del")
@empty_response
+ @NvmeofCLICommand("nvmeof host del")
+ @convert_to_model(model.RequestStatus)
@handle_nvmeof_error
def delete(self, nqn: str, host_nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.remove_host(
"gw_group": Param(str, "NVMeoF gateway group", True, None),
},
)
+ @pick("connections")
@NvmeofCLICommand("nvmeof connection list")
- @map_collection(model.Connection, pick="connections")
+ @convert_to_model(model.ConnectionList)
@handle_nvmeof_error
def list(self, nqn: str, gw_group: Optional[str] = None):
return NVMeoFClient(gw_group=gw_group).stub.list_connections(
-from typing import NamedTuple, Optional
+from typing import List, NamedTuple, Optional
class GatewayInfo(NamedTuple):
class GatewayLogLevelInfo(NamedTuple):
status: int
error_message: str
- log_level: int
+ log_level: str
class SpdkNvmfLogFlagsAndLevelInfo(NamedTuple):
status: int
error_message: str
- log_level: int
- log_print_level: int
+ log_level: str
+ log_print_level: str
class Subsystem(NamedTuple):
max_namespaces: int
+class SubsystemList(NamedTuple):
+ status: int
+ error_message: str
+ subsystems: List[Subsystem]
+
+
class Connection(NamedTuple):
traddr: str
trsvcid: int
controller_id: int
+class ConnectionList(NamedTuple):
+ status: int
+ error_message: str
+ subsystem_nqn: str
+ connections: List[Connection]
+
+
class NamespaceCreation(NamedTuple):
+ status: int
+ error_message: str
nsid: int
trash_image: bool
+class NamespaceList(NamedTuple):
+ status: int
+ error_message: str
+ namespaces: List[Namespace]
+
+
class NamespaceIOStats(NamedTuple):
+ status: int
+ error_message: str
+ subsystem_nqn: str
nsid: int
uuid: str
bdev_name: str
copy_latency_ticks: int
max_copy_latency_ticks: int
min_copy_latency_ticks: int
- # io_error: List[int]
+ io_error: List[int]
class Listener(NamedTuple):
trsvcid: int = 4420
+class ListenerList(NamedTuple):
+ status: int
+ error_message: str
+ listeners: List[Listener]
+
+
class Host(NamedTuple):
nqn: str
import functools
import logging
from collections.abc import Iterable
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type
+from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Type
from ..exceptions import DashboardException
from .nvmeof_conf import NvmeofGatewaysConfig
try:
import grpc # type: ignore
import grpc._channel # type: ignore
+ from google.protobuf.json_format import MessageToDict # type: ignore
from google.protobuf.message import Message # type: ignore
from .proto import gateway_pb2 as pb2 # type: ignore
def map_collection(
model: Type[NamedTuple],
- pick: str,
+ pick: str, # pylint: disable=redefined-outer-name
finalize: Optional[Callable[[Message, Collection], Collection]] = None,
) -> Callable[..., Callable[..., Collection]]:
def decorator(func: Callable[..., Message]) -> Callable[..., Collection]:
func(*args, **kwargs)
return wrapper
+
+ class MaxRecursionDepthError(Exception):
+ pass
+
+ def _convert(value, field_type, depth, max_depth) -> Generator:
+ if depth > max_depth:
+ raise MaxRecursionDepthError(
+ f"Maximum nesting depth of {max_depth} exceeded at depth {depth}.")
+
+ if isinstance(value, dict) and hasattr(field_type, '_fields'):
+ # Lazily create NamedTuple for nested dicts
+ yield from _lazily_create_namedtuple(value, field_type, depth + 1, max_depth)
+ elif isinstance(value, list):
+ # Handle empty lists directly
+ if not value:
+ yield []
+ else:
+ # Lazily process each item in the list based on the expected item type
+ item_type = field_type.__args__[0] if hasattr(field_type, '__args__') else None
+ processed_items = []
+ for v in value:
+ if item_type:
+ processed_items.append(next(_convert(v, item_type,
+ depth + 1, max_depth), None))
+ else:
+ processed_items.append(v)
+ yield processed_items
+ else:
+ # Yield the value as is for simple types
+ yield value
+
+ def _lazily_create_namedtuple(data: Any, target_type: Type[NamedTuple],
+ depth: int, max_depth: int) -> Generator:
+ """ Lazily create NamedTuple from a dict """
+ field_values = {}
+ for field, field_type in zip(target_type._fields,
+ target_type.__annotations__.values()):
+ # these conditions are complex since we need to navigate between dicts,
+ # empty dicts and objects
+ if isinstance(data, dict) and data.get(field) is not None:
+ try:
+ field_values[field] = next(_convert(data.get(field), field_type,
+ depth, max_depth), None)
+ except StopIteration:
+ return
+ elif hasattr(data, field):
+ try:
+ field_values[field] = next(_convert(getattr(data, field), field_type,
+ depth, max_depth), None)
+ except StopIteration:
+ return
+ else:
+ # If the field is missing assign None
+ field_values[field] = None
+
+ namedtuple_instance = target_type(**field_values) # type: ignore
+ yield namedtuple_instance
+
+ def obj_to_namedtuple(data: Any, target_type: Type[NamedTuple],
+ max_depth: int = 4) -> NamedTuple:
+ """
+ Convert an object or dict to a NamedTuple, handling nesting and lists lazily.
+ This will raise an error if nesting depth exceeds the max depth (default 4)
+ to avoid bloating the memory in case of mutual references between objects.
+
+ :param data: The input data - object or dictionary
+ :param target_type: The target NamedTuple type
+ :param max_depth: The maximum depth allowed for recursion
+ :return: An instance of the target NamedTuple with fields populated from the JSON
+ """
+
+ if not isinstance(target_type, type) or not hasattr(target_type, '_fields'):
+ raise TypeError("target_type must be a NamedTuple type.")
+ if isinstance(data, list):
+ raise TypeError("data can't be a list.")
+ if data is None:
+ raise TypeError("data can't be None.")
+ namedtuple_values = next(_lazily_create_namedtuple(data, target_type, 1, max_depth))
+ return namedtuple_values
+
+ def namedtuple_to_dict(obj):
+ if isinstance(obj, tuple) and hasattr(obj, '_asdict'):
+ # If it's a namedtuple, convert it to a dictionary
+ return {k: namedtuple_to_dict(v) for k, v in obj._asdict().items()}
+ if isinstance(obj, list):
+ # If it's a list, check each item and convert if it's a namedtuple
+ return [
+ namedtuple_to_dict(item)
+ if isinstance(item, tuple) and hasattr(item, '_asdict')
+ else item
+ for item in obj
+ ]
+ return obj
+
+ def convert_to_model(model: Type[NamedTuple]) -> Callable[..., Callable[..., Model]]:
+ def decorator(func: Callable[..., Message]) -> Callable[..., Model]:
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs) -> Model:
+ message = func(*args, **kwargs)
+ msg_dict = MessageToDict(message, including_default_value_fields=True,
+ preserving_proto_field_name=True)
+ return namedtuple_to_dict(obj_to_namedtuple(msg_dict, model))
+
+ return wrapper
+
+ return decorator
+
+ # pylint: disable-next=redefined-outer-name
+ def pick(field: str, first: bool = False) -> Callable[..., Callable[..., object]]:
+ def decorator(func: Callable[..., Dict]) -> Callable[..., object]:
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs) -> object:
+ model = func(*args, **kwargs)
+ field_to_ret = model[field]
+ if first:
+ field_to_ret = field_to_ret[0]
+ return field_to_ret
+ return wrapper
+ return decorator
--- /dev/null
+from typing import Dict, List, NamedTuple, Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from ..services import nvmeof_client
+from ..services.nvmeof_client import MaxRecursionDepthError, convert_to_model, \
+ obj_to_namedtuple, pick
+
+
+class TestObjToNamedTuple:
+ def test_basic(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+
+ class P:
+ def __init__(self, name, age):
+ self._name = name
+ self._age = age
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def age(self):
+ return self._age
+
+ obj = P("Alice", 25)
+
+ person = obj_to_namedtuple(obj, Person)
+ assert person.name == "Alice"
+ assert person.age == 25
+
+ def test_nested(self):
+ class Address(NamedTuple):
+ street: str
+ city: str
+
+ class Person(NamedTuple):
+ name: str
+ age: int
+ address: Address
+
+ obj = MagicMock()
+ obj.name = "Bob"
+ obj.age = 30
+ obj.address.street = "456 Oak St"
+ obj.address.city = "Springfield"
+
+ person = obj_to_namedtuple(obj, Person)
+ assert person.name == "Bob"
+ assert person.age == 30
+ assert person.address.street == "456 Oak St"
+ assert person.address.city == "Springfield"
+
+ def test_empty_obj(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+
+ obj = object()
+
+ person = obj_to_namedtuple(obj, Person)
+ assert person.name is None
+ assert person.age is None
+
+ def test_empty_list_or_dict(self):
+ class Person(NamedTuple):
+ name: str
+ hobbies: List[str]
+ address: Dict[str, str]
+
+ class P:
+ def __init__(self, name, hobbies, address):
+ self._name = name
+ self._hobbies = hobbies
+ self._address = address
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def hobbies(self):
+ return self._hobbies
+
+ @property
+ def address(self):
+ return self._address
+ name = "George"
+ obj = P(name, [], {})
+
+ person = obj_to_namedtuple(obj, Person)
+ assert person.name == "George"
+ assert person.hobbies == []
+ assert person.address == {}
+
+
+class TestJsonToNamedTuple:
+
+ def test_basic(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+
+ json_data = {
+ "name": "Alice",
+ "age": 25
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Alice"
+ assert person.age == 25
+
+ def test_nested(self):
+ class Address(NamedTuple):
+ street: str
+ city: str
+
+ class Person(NamedTuple):
+ name: str
+ age: int
+ address: Address
+
+ json_data = {
+ "name": "Bob",
+ "age": 30,
+ "address": {
+ "street": "456 Oak St",
+ "city": "Springfield"
+ }
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Bob"
+ assert person.age == 30
+ assert person.address.street == "456 Oak St"
+ assert person.address.city == "Springfield"
+
+ def test_list(self):
+ class Person(NamedTuple):
+ name: str
+ hobbies: List[str]
+
+ json_data = {
+ "name": "Charlie",
+ "hobbies": ["reading", "cycling", "swimming"]
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Charlie"
+ assert person.hobbies == ["reading", "cycling", "swimming"]
+
+ def test_nested_list(self):
+ class Address(NamedTuple):
+ street: str
+ city: str
+
+ class Person(NamedTuple):
+ name: str
+ addresses: List[Address]
+
+ json_data = {
+ "name": "Diana",
+ "addresses": [
+ {"street": "789 Pine St", "city": "Oakville"},
+ {"street": "101 Maple Ave", "city": "Mapleton"}
+ ]
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Diana"
+ assert len(person.addresses) == 2
+ assert person.addresses[0].street == "789 Pine St"
+ assert person.addresses[1].street == "101 Maple Ave"
+
+ def test_missing_fields(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+ address: str
+
+ json_data = {
+ "name": "Eva",
+ "age": 40
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Eva"
+ assert person.age == 40
+ assert person.address is None
+
+ def test_redundant_fields(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+
+ json_data = {
+ "name": "Eva",
+ "age": 40,
+ "last_name": "Cohen"
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "Eva"
+ assert person.age == 40
+
+ def test_max_depth_exceeded(self):
+ class Bla(NamedTuple):
+ a: str
+
+ class Address(NamedTuple):
+ street: str
+ city: str
+ bla: Bla
+
+ class Person(NamedTuple):
+ name: str
+ address: Address
+
+ json_data = {
+ "name": "Frank",
+ "address": {
+ "street": "123 Elm St",
+ "city": "Somewhere",
+ "bla": {
+ "a": "blabla",
+ }
+ }
+ }
+
+ with pytest.raises(MaxRecursionDepthError):
+ obj_to_namedtuple(json_data, Person, max_depth=2)
+
+ def test_empty_json(self):
+ class Person(NamedTuple):
+ name: str
+ age: int
+
+ json_data = {}
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name is None
+ assert person.age is None
+
+ def test_empty_list_or_dict(self):
+ class Person(NamedTuple):
+ name: str
+ hobbies: List[str]
+ address: Dict[str, str]
+
+ json_data = {
+ "name": "George",
+ "hobbies": [],
+ "address": {}
+ }
+
+ person = obj_to_namedtuple(json_data, Person)
+ assert person.name == "George"
+ assert person.hobbies == []
+ assert person.address == {}
+
+ def test_depth_within_limit(self):
+ class Address(NamedTuple):
+ street: str
+ city: str
+
+ class Person(NamedTuple):
+ name: str
+ address: Address
+
+ json_data = {
+ "name": "Helen",
+ "address": {
+ "street": "123 Main St",
+ "city": "Metropolis"
+ }
+ }
+
+ person = obj_to_namedtuple(json_data, Person, max_depth=4)
+ assert person.name == "Helen"
+ assert person.address.street == "123 Main St"
+ assert person.address.city == "Metropolis"
+
+
+class Boy(NamedTuple):
+ name: str
+ age: int
+
+
+class Adult(NamedTuple):
+ name: str
+ age: int
+ children: List[str]
+ hobby: Optional[str]
+
+
+class EmptyModel(NamedTuple):
+ pass
+
+
+@pytest.fixture(name="person_func")
+def fixture_person_func():
+ @convert_to_model(Boy)
+ def get_person() -> dict:
+ return {"name": "Alice", "age": 30}
+ return get_person
+
+
+@pytest.fixture(name="empty_func")
+def fixture_empty_func():
+ @convert_to_model(EmptyModel)
+ def get_empty_data() -> dict:
+ return {}
+ return get_empty_data
+
+
+@pytest.fixture(name="disable_message_to_dict")
+def fixture_disable_message_to_dict(monkeypatch):
+ monkeypatch.setattr(nvmeof_client, 'MessageToDict', lambda x, **kwargs: x)
+
+
+class TestConvertToModel:
+ def test_basic_functionality(self, person_func, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ result = person_func()
+ assert result == {'name': 'Alice', 'age': 30}
+
+ def test_empty_output(self, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ @convert_to_model(Boy)
+ def get_empty_person() -> dict:
+ return {}
+
+ result = get_empty_person()
+ assert result == {'name': None, 'age': None} # Assuming default values for empty fields
+
+ def test_non_dict_return_value(self, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ @convert_to_model(Boy)
+ def get_person_list() -> list:
+ return ["Alice", 30] # This is an invalid return type
+
+ with pytest.raises(TypeError):
+ get_person_list()
+
+ def test_optional_fields(self, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ @convert_to_model(Adult)
+ def get_adult() -> dict:
+ return {"name": "Charlie", "age": 40, "children": []}
+
+ result = get_adult()
+ assert result == {'name': 'Charlie', 'age': 40, "children": [], 'hobby': None}
+
+ def test_nested_fields(self, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ @convert_to_model(Adult)
+ def get_adult() -> dict:
+ return {"name": "Charlie", "age": 40, "children": [{"name": "Alice", "age": 30}]}
+
+ result = get_adult()
+ assert result == {'name': 'Charlie', 'age': 40,
+ "children": [{"name": "Alice", "age": 30}], 'hobby': None}
+
+ def test_none_as_input(self, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ @convert_to_model(Boy)
+ def get_none_person() -> dict:
+ return None
+
+ with pytest.raises(TypeError):
+ get_none_person()
+
+ def test_multiple_function_calls(self, person_func, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ result1 = person_func()
+ result2 = person_func()
+ assert result1 == result2
+
+ def test_empty_model(self, empty_func, disable_message_to_dict):
+ # pylint: disable=unused-argument
+ result = empty_func()
+ assert result == {}
+
+
+class TestPick:
+ def test_basic_field_access(self):
+ @pick("name")
+ def get_person():
+ return {"name": "Alice", "height": 170}
+
+ assert get_person() == "Alice"
+
+ def test_first_true_on_string_field(self):
+ @pick("name", first=True)
+ def get_person():
+ return {"name": "Alice"}
+
+ assert get_person() == "A"
+
+ def test_first_true_on_list_field(self):
+ @pick("tags", first=True)
+ def get_item():
+ return {"item": "Shirt", "tags": ["red", "cotton", "medium"]}
+
+ assert get_item() == "red"
+
+ def test_default_field_access_on_list_field(self):
+ @pick("tags")
+ def get_item():
+ return {"item": "Shirt", "tags": ["red", "cotton", "medium"]}
+
+ assert get_item() == ["red", "cotton", "medium"]
+
+ def test_nested_models(self):
+ @pick("address")
+ def get_person():
+ return {"name": "Alice", "address": {"state": "New York", "country": "USA"}}
+
+ assert get_person() == {"state": "New York", "country": "USA"}
+
+ def test_field_not_present(self):
+ @pick("email")
+ def get_person():
+ return {"name": "Alice", "address": {"state": "New York", "country": "USA"}}
+
+ with pytest.raises(KeyError):
+ get_person()
+
+ def test_first_true_on_empty_collection(self):
+ @pick("tags", first=True)
+ def get_item():
+ return {"item": "Shirt", "tags": []}
+ with pytest.raises(IndexError):
+ get_item()
+
+ def test_first_true_on_empty_string(self):
+ @pick("name", first=True)
+ def get_person():
+ return {"name": ""}
+ with pytest.raises(IndexError):
+ get_person()
+
+ def test_none_type_field(self):
+ @pick("job")
+ def get_person():
+ return {"name": ""}
+ with pytest.raises(KeyError):
+ get_person()
+
+ def test_none_model(self):
+ @pick("name")
+ def get_person():
+ return None
+ with pytest.raises(TypeError):
+ get_person()