from collections import OrderedDict
from contextlib import contextmanager
from functools import wraps
-from ipaddress import ip_network, ip_address, ip_interface
+from ipaddress import (
+ IPv4Network,
+ IPv6Network,
+ ip_address,
+ ip_interface,
+ ip_network,
+)
from typing import (
Any,
Callable,
return out
+class SMBClusterBindIPSpec:
+ """Control what IPs the SMB services will listen on, not including
+ dynamic IPs that are managed by CTDB.
+ """
+ def __init__(
+ self,
+ # single address
+ address: Optional[str] = None,
+ # >1 address specified as a network
+ network: Optional[str] = None,
+ ) -> None:
+ self.address = address
+ self.network = network
+ self._networks: List[Union[IPv4Network, IPv6Network]] = []
+ self.validate()
+
+ def validate(self) -> None:
+ if self.address and self.network:
+ raise SpecValidationError('only one of address or network may be given')
+ if not (self.address or self.network):
+ raise SpecValidationError('one of address or network is required')
+ if self.address:
+ # verify that address is an address
+ try:
+ ip_address(self.address)
+ except ValueError as err:
+ raise SpecValidationError(
+ f'Cannot parse address {self.address}'
+ ) from err
+ # but we internallly store a list of networks
+ # this is slight bit of YAGNI violation, but I actually plan on
+ # adding IP ranges soon.
+ addr = self.network if self.network else self.address
+ try:
+ assert addr
+ self._networks = [ip_network(addr)]
+ except ValueError as err:
+ raise SpecValidationError(
+ f'Cannot parse network address {addr}'
+ ) from err
+
+ def as_networks(self) -> List[Union[IPv4Network, IPv6Network]]:
+ """Return a list of one or more IPv4 or IPv6 network objects."""
+ if not self._networks:
+ self.validate()
+ return self._networks
+
+ def as_network_strs(self) -> List[str]:
+ """Return a list of strings containing one or more network (<ip>/<mask>
+ style) values.
+ """
+ return [str(n) for n in self.as_networks()]
+
+ def __eq__(self, other: Any) -> bool:
+ try:
+ return (
+ other.address == self.address
+ and other.network == self.network
+ )
+ except AttributeError:
+ return NotImplemented
+
+ def __repr__(self) -> str:
+ if self.address:
+ return f'SMBClusterBindIPSpec(address={self.address!r})'
+ if self.network:
+ return f'SMBClusterBindIPSpec(network={self.network!r})'
+ raise ValueError('SMBClusterBindIPSpec missing address or network value')
+
+ def to_simplified(self) -> Dict[str, Any]:
+ """Return a serializable representation of SMBClusterBindIPSpec."""
+ if self.address:
+ return {'address': self.address}
+ if self.network:
+ return {'network': self.network}
+ raise ValueError('SMBClusterBindIPSpec missing address or network value')
+
+ def to_json(self) -> Dict[str, Any]:
+ """Return a JSON-compatible dict."""
+ return self.to_simplified()
+
+ @classmethod
+ def from_json(cls, spec: Dict[str, Any]) -> 'SMBClusterBindIPSpec':
+ """Convert value from a JSON-compatible dict."""
+ return cls(**spec)
+
+ @classmethod
+ def convert_list(
+ cls, arg: Optional[List[Any]]
+ ) -> Optional[List['SMBClusterBindIPSpec']]:
+ """Convert a list of values into a list of SMBClusterBindIPSpec objects.
+ Ignores None inputs returning None.
+ """
+ if arg is None:
+ return None
+ assert isinstance(arg, list)
+ out = []
+ for value in arg:
+ if isinstance(value, cls):
+ out.append(value)
+ elif hasattr(value, 'to_json'):
+ out.append(cls.from_json(value.to_json()))
+ elif isinstance(value, dict):
+ out.append(cls.from_json(value))
+ else:
+ raise SpecValidationError(
+ f"Unknown type for SMBClusterBindIPSpec: {type(value)}"
+ )
+ return out
+
+
class SMBSpec(ServiceSpec):
service_type = 'smb'
_valid_features = {'domain', 'clustered', 'cephfs-proxy'}
# custom_ports - A mapping of services to ports. If a service is
# not listed the default port will be used.
custom_ports: Optional[Dict[str, int]] = None,
+ bind_addrs: Optional[List[SMBClusterBindIPSpec]] = None,
# --- genearal tweaks ---
extra_container_args: Optional[GeneralArgList] = None,
extra_entrypoint_args: Optional[GeneralArgList] = None,
cluster_public_addrs
)
self.custom_ports = custom_ports
+ self.bind_addrs = SMBClusterBindIPSpec.convert_list(bind_addrs)
self.validate()
def validate(self) -> None:
def strict_cluster_ip_specs(self) -> List[Dict[str, Any]]:
return [s.to_strict() for s in (self.cluster_public_addrs or [])]
+ def bind_networks(self) -> List[str]:
+ """Return a list of all networks (as an addr/mask) that this service is
+ permitted to bind to.
+ """
+ out = []
+ for ba in self.bind_addrs or []:
+ out.extend(ba.as_network_strs())
+ return out
+
def to_json(self) -> "OrderedDict[str, Any]":
obj = super().to_json()
spec = obj.get('spec')
spec['cluster_public_addrs'] = [
a.to_json() for a in spec['cluster_public_addrs']
]
+ if spec and spec.get('bind_addrs'):
+ spec['bind_addrs'] = [a.to_json() for a in spec['bind_addrs']]
return obj