Combining a descriptor class with dataclass and field
I am using a dataclass and field to pass in a default value. When an argument is provided I want to validate it using a descriptor class.
Is there any way to utilize the benefits of field (repr, default, init, etc) while getting the validator benefits of a descriptor class?
from dataclasses import dataclass, field
class Descriptor:
    """Data descriptor from the question; keeps the value in the instance ``__dict__``.

    The logic is reproduced exactly as asked about (only the lost indentation
    is restored), including the behaviour the question is about: a truthy
    assigned value is wrapped in a ``dataclasses.Field`` instead of being
    stored as a plain ``int``.
    """

    def __init__(self, default):
        # Value stored whenever a falsy value is assigned.
        self.default = default

    def __set_name__(self, owner, name):
        # Remember the attribute name; it is the key used in ``vars(obj)``.
        self.name = name

    def __get__(self, obj, objtype=None):
        # Truthiness test: class access (obj is None) and falsy instances
        # both yield None.
        if obj:
            return vars(obj).get(self.name)
        else:
            return None

    def __set__(self, obj, value):
        if not value:
            # Falsy input (None, 0, '', ...) falls back to the default.
            value = self.default
        else:
            # NOTE: this stores a Field object, not the converted integer.
            value = field(default=int(value), repr=False)
        vars(obj)[self.name] = value
@dataclass
class Person:
    # The descriptor is handed a Field object as its default value.
    age: str = Descriptor(field(default=3, repr=False))
    # Many additional attributes
    # using same descriptor class


p = Person()
r = Person(2.37)
Solution 1:[1]
Your idea is OK, but it isn't pretty and it doesn't work.
My first idea was for Descriptor to inherit from dataclasses.Field, but that doesn't work either.
What you should do is somewhat roundabout, but it is how things should be done even without dataclasses:
make two attributes, age and _age.
import dataclasses as dc
class GetSet:
    """Validating data descriptor that stores its value in a private attribute.

    The value lives on the instance under ``_<name>``, where ``<name>`` is the
    attribute the descriptor is bound to (``age`` -> ``_age``), so the same
    descriptor class can guard several attributes of one class.

    Args:
        predicate: Callable returning a truthy result for acceptable values.
            The default accepts everything.
    """

    def __init__(self, predicate=lambda value: True):
        self.predicate = predicate
        # Fallback storage name, used if the descriptor is attached without
        # ``__set_name__`` being called (e.g. assigned after class creation).
        self.private_name = '_age'

    def __set_name__(self, owner, name):
        # Derive the per-attribute storage slot from the bound name.
        self.private_name = '_' + name

    def __set__(self, obj, value):
        if isinstance(value, GetSet):
            # The dataclass-generated __init__ uses the descriptor itself as
            # the default argument and tries `obj.age = GetSet(...)`; ignore it.
            return
        if not self.predicate(value):
            raise ValueError(
                f'invalid value for {self.private_name[1:]!r}: {value!r}')
        setattr(obj, self.private_name, value)

    def __get__(self, obj, cls):
        if obj is not None:  # called as obj.age
            return getattr(obj, self.private_name)
        return self  # called as Person.age
@dc.dataclass
class Person:
    # this one doesn't go to init, hence init=False, rest is as you wish
    _age: int = dc.field(default=20, repr=False, init=False)
    # this one goes to init, if that is not desired, just remove annotation `: int`
    age: int = GetSet(lambda age: age >= 18)  # this attribute will not be allowed to be under 18
An alternative is to use the built-in property:
@dc.dataclass
class Person:
    """Person whose ``age`` is validated through a property.

    ``_age`` holds the value (default 20) and is excluded from ``__init__``
    and ``repr``; ``age`` is the public, validated attribute.
    """

    _age: int = dc.field(default=20, repr=False, init=False)
    age: int

    @property
    def age(self):
        return self._age

    @age.setter
    def age(self, value):
        if isinstance(value, property):
            # The property object itself ends up as the dataclass default for
            # `age`, so __init__ assigns it when no age is passed; ignore it.
            return
        if value < 18:
            raise ValueError('no underaged')
        self._age = value
Solution 2:[2]
There is a way to get all the benefits. Note that dataclass is generating code for you, so you can do modifications by inheritance.
from dataclasses import dataclass
from dataclasses import field
from typing import Union
class Descriptor:
    """Data descriptor that converts ``str`` values to ``int`` on assignment.

    Values are stored in the instance ``__dict__`` under the attribute name
    the descriptor is bound to.
    """

    def __set_name__(self, owner, name):
        # Key used for storage in ``vars(obj)``.
        self.name = name

    def __get__(self, obj, obj_type=None):
        # Compare against None explicitly: a truthiness test would treat an
        # instance that is falsy (e.g. defines __len__ or __bool__) as
        # class-level access and wrongly return None.
        if obj is None:
            return None
        # Raises KeyError if the attribute has not been assigned yet.
        return vars(obj)[self.name]

    def __set__(self, obj, value):
        if isinstance(value, str):
            value = int(value)
        vars(obj)[self.name] = value
@dataclass
class PersonData:
    """Plain dataclass carrying the field definitions (defaults, repr, ...)."""

    name: str
    # Excluded from repr; defaults to 3 when not passed to the constructor.
    age: Union[int, str] = field(default=3, repr=False)
    # Many additional attributes you want to get the benefit of field for
class Person(PersonData):
    # Shadows the inherited dataclass attribute with the validating
    # descriptor on the subclass.
    age = Descriptor()
    # Many additional attributes you want to use descriptors for
if __name__ == '__main__':
    p = Person('Mike')
    r = Person('Mary', 2)
    print(p, p.age)
    print(r, r.age)
    # Re-assigning goes through Descriptor.__set__ as well.
    r.age = 5
    print(p, p.age)
    print(r, r.age)
You will get the following print:
Person(name='Mike') 3
Person(name='Mary') 2
Person(name='Mike') 3
Person(name='Mary') 5
You can see that the default value defined in PersonData is passed on to Person automatically. There is no need to store default in Descriptor anymore.
Solution 3:[3]
You can use a descriptor with a dataclasses.field (at least after this bug was fixed).
There are just a few things that needed to be changed in your code:
Starting in the dataclass, you had the order wrong as to which object is calling which object. The field object is what should be attached to the dataclass, with the Descriptor as its default:
@dataclass
class Person:
    # The Field wraps the descriptor, not the other way round.
    age: str = field(default=Descriptor(default=3), repr=False)
Next, in Descriptor.__set__, when the age argument is not provided to the constructor, the value argument will actually be the instance of the Descriptor class. So we need to change the guard to check whether value is self:
class Descriptor:
    ...

    def __set__(self, obj, value):
        # When no argument is passed to the constructor, the descriptor
        # itself arrives here as the value.
        if value is self:
            value = self.default
        ...
Finally, I made one more change to echo the patterns I've seen in the Python ecosystem: using the getattr and setattr functions for getting and setting attributes on classes.
Unfortunately, this introduced an infinite recursion bug, so I changed the place the value is stored on the Person object to _age.
All that being said, this works as you intended:
from dataclasses import dataclass, field
class Descriptor:
    """Data descriptor that coerces assigned values to ``int``.

    The value is stored on the instance under ``_<name>`` so that
    ``getattr``/``setattr`` on the storage attribute do not re-enter the
    descriptor.

    Args:
        default: Value stored when the descriptor itself is assigned, which
            is what happens when the constructor argument is omitted.
    """

    def __init__(self, default):
        self.default = default

    def __set_name__(self, owner, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access (e.g. `Person.age`): return the descriptor
            # instead of failing on `getattr(None, ...)`.
            return self
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        if value is self:
            # No argument was supplied, so the descriptor arrived as the value.
            value = self.default
        else:
            value = int(value)
        setattr(obj, self.private_name, value)
@dataclass
class Person:
    age: str = field(default=Descriptor(default=3), repr=False)
    # Many additional attributes
    # using same descriptor class


r = Person(2.37)
assert r.age == 2
p = Person()
assert p.age == 3
print(r)
print(p)
print(vars(p))
Solution 4:[4]
Using a descriptor with a dataclass is quite tricky and requires some ugly hacks to make it work (to support default_factory, frozen, ...). I have some example code working in: https://github.com/google/etils/blob/main/etils/edc/field_utils.py
from etils import edc # pip install etils[edc]
@dataclasses.dataclass
class A:
    path: epath.Path = edc.field(validate=epath.Path)
    x: int = edc.field(validate=int)
    y: int = edc.field(validate=lambda x: -x, default=5)


a = A(
    path='/some/path',  # Inputs auto-normalized `str` -> `epath.Path`
    x='123',
)
assert isinstance(a.path, epath.Path)
assert a.x == 123
assert a.y == -5
Here is the implementation:
"""Field utils."""
from __future__ import annotations
import dataclasses
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar
_Dataclass = Any
_In = Any
_Out = Any
_InT = TypeVar('_InT')
_OutT = TypeVar('_OutT')
def field(
    *,
    validate: Optional[Callable[[_In], _OutT]] = None,
    **kwargs: Any,
) -> dataclasses.Field[_OutT]:
  """Like `dataclasses.field`, but allow `validate`.

  Args:
    validate: A callable `(x) -> x` called each time the variable is assigned.
    **kwargs: Kwargs forwarded to `dataclasses.field`

  Returns:
    The field.
  """
  if validate is None:
    # No validation requested: behave exactly like `dataclasses.field`.
    return dataclasses.field(**kwargs)
  # Return the descriptor, typed as a `dataclasses.Field` for type checkers.
  field_ = _Field(validate=validate, field_kwargs=kwargs)
  return typing.cast(dataclasses.Field, field_)
class _Field(Generic[_InT, _OutT]):
  """Field descriptor."""

  def __init__(
      self,
      validate: Callable[[_InT], _OutT],
      field_kwargs: dict[str, Any],
  ) -> None:
    """Constructor.

    Args:
      validate: A callable called each time the variable is assigned.
      field_kwargs: Kwargs forwarded to `dataclasses.field`
    """
    # Attribute name and objtype refer to the object in which the descriptor
    # is applied. E.g. if `A.x = edc.field()`:
    # * _attribute_name = 'x'
    # * _objtype = A
    self._attribute_name: Optional[str] = None
    self._objtype: Optional[Type[_Dataclass]] = None

    self._validate = validate
    self._field_kwargs = field_kwargs

    # Whether `__get__` has not been called yet. See `__get__` for details.
    self._first_getattr_call: bool = True

  def __set_name__(self, objtype: Type[_Dataclass], name: str) -> None:
    """Bind the descriptor to the class (PEP 487)."""
    self._objtype = objtype
    self._attribute_name = name

  def __get__(
      self,
      obj: Optional[_Dataclass],
      objtype: Optional[Type[_Dataclass]] = None,
  ) -> _OutT:
    """Called when `MyDataclass.x` or `my_dataclass.x`."""
    # Called as `MyDataclass.my_attribute`
    if obj is None:
      if self._first_getattr_call:
        # Count the number of times `dataclasses.dataclass(cls)` calls
        # `getattr(cls, f.name)`.
        # The first time, we return a `dataclasses.Field` to let dataclass
        # do the magic.
        # The second time, `dataclasses.dataclass` deletes the descriptor if
        # `isinstance(getattr(cls, f.name, None), Field)`. So it is very
        # important to return anything except a `dataclasses.Field`.
        # This relies on an implementation detail, but seems to hold for
        # python 3.6-3.10.
        self._first_getattr_call = False
        return dataclasses.field(**self._field_kwargs)
      else:
        # TODO(epot): Could better handle default value: Either by returning
        # the default value, or raising an AttributeError. Currently, we just
        # return the descriptor:
        # assert isinstance(MyDataclass.my_attribute, _Field)
        return self
    else:
      # Called as `my_dataclass.my_path`
      return _getattr(obj, self._attribute_name)

  def __set__(self, obj: _Dataclass, value: _InT) -> None:
    """Called as `my_dataclass.x = value`."""
    # Validate the value during assignment
    _setattr(obj, self._attribute_name, self._validate(value))
# Because there is one instance of the `_Field` per class, shared across all
# class instances, we need to store the per-object state somewhere.
# The simplest is to attach the state in an extra `dict[str, value]`:
# `_dataclass_field_values`.
def _getattr(
    obj: _Dataclass,
    attribute_name: str,
) -> _Out:
  """Returns the `obj.attribute_name`.

  Args:
    obj: Instance whose per-object field state is read.
    attribute_name: Name of the field to look up.

  Returns:
    The value previously stored for `attribute_name`.

  Raises:
    AttributeError: If the attribute was never set on `obj`.
  """
  _init_dataclass_state(obj)
  values = obj._dataclass_field_values  # pylint: disable=protected-access
  # Accessing the attribute before it was set (e.g. before super().__init__)
  if attribute_name not in values:
    # `obj` is an instance, so use the instance form of the message.
    raise AttributeError(
        f"'{type(obj).__qualname__}' object has no attribute "
        f"'{attribute_name}'")
  return values[attribute_name]
def _setattr(
    obj: _Dataclass,
    attribute_name: str,
    value: _In,
) -> None:
  """Set the `obj.attribute_name = value`.

  Args:
    obj: Instance whose per-object field state is updated.
    attribute_name: Name of the field to store.
    value: Value to record for the field.
  """
  # Note: In `dataclasses.dataclass(frozen=True)`, obj.__setattr__ will
  # correctly raise a `FrozenInstanceError` before `DataclassField.__set__` is
  # called.
  _init_dataclass_state(obj)
  obj._dataclass_field_values[attribute_name] = value  # pylint: disable=protected-access
def _init_dataclass_state(obj: _Dataclass) -> None:
"""Initialize the object state containing all DataclassField values."""
if not hasattr(obj, '_dataclass_field_values'):
# Use object.__setattr__ for frozen dataclasses
object.__setattr__(obj, '_dataclass_field_values', {})
There might be a simpler way, but this one was fully tested.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | |
Solution 2 | daizhirui |
Solution 3 | |
Solution 4 | Conchylicultor |