How to generate a Pydantic model for multiple different objects
I need to have a variable `covars` that contains an unknown number of entries, where each entry is one of three different custom Pydantic models. In this case, each entry describes a variable for my application.
Specifically, I want `covars` to have the following form. It is shown here for three entries, namely `variable1`, `variable2` and `variable3`, representing the three different types of entries. When deployed, however, the application must accept more than three entries, and not all entry types need to be present in a request.
```python
covars = {
    'variable1': {  # type: integer
        'guess': 1,
        'min': 0,
        'max': 2,
    },
    'variable2': {  # type: continuous
        'guess': 12.2,
        'min': -3.4,
        'max': 30.8,
    },
    'variable3': {  # type: categorical
        'guess': 'red',
        'options': {'red', 'blue', 'green'},
    },
}
```
I have successfully created the three different entry types as three separate Pydantic models:
```python
import pydantic
from typing import Set, Dict, Union

class IntVariable(pydantic.BaseModel):
    guess: int
    min: int
    max: int

class ContVariable(pydantic.BaseModel):
    guess: float
    min: float
    max: float

class CatVariable(pydantic.BaseModel):
    guess: str
    options: Set[str] = set()  # empty set by default; `{}` would be an empty dict
```
Notice the data type difference between `IntVariable` and `ContVariable`.
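Because the two numeric entry types share exactly the same keys, the only thing that distinguishes them is the Python type of their values. The hypothetical helper `entry_kind` below (plain standard library, not part of the question or either answer) makes that explicit for a `covars` dict like the one above:

```python
from typing import Any, Dict


def entry_kind(entry: Dict[str, Any]) -> str:
    """Return 'integer', 'continuous' or 'categorical' for one covars entry.

    Hypothetical helper: the kind is inferred from the keys and, for the two
    numeric kinds, from the exact Python types of the values.
    """
    keys = set(entry)
    if keys == {"guess", "options"}:
        return "categorical"
    if keys == {"guess", "min", "max"}:
        values = [entry["guess"], entry["min"], entry["max"]]
        # exact type checks, so that bool (a subclass of int) is rejected
        if all(type(v) is int for v in values):
            return "integer"
        if all(type(v) in (int, float) for v in values):
            return "continuous"
    raise ValueError(f"unrecognised entry: {entry!r}")


covars = {
    "variable1": {"guess": 1, "min": 0, "max": 2},
    "variable2": {"guess": 12.2, "min": -3.4, "max": 30.8},
    "variable3": {"guess": "red", "options": {"red", "blue", "green"}},
}
kinds = {name: entry_kind(entry) for name, entry in covars.items()}
```

Note that an entry such as `{"guess": 2, "min": 0, "max": 5}` is classified as integer even if a continuous variable was intended; that ambiguity is exactly what a model combining both types has to resolve.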
My question: how can I make a Pydantic model that accepts any number of entries of types `IntVariable`, `ContVariable` and `CatVariable` to get the output I am looking for?
The plan is to use this model to validate the data as it is posted to the API, and then store a serialized version in the application db (using `ormar`).
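One detail to keep in mind for the storage step: `options` is a Python `set`, which the standard `json` module cannot serialize. A minimal sketch of one way to produce a JSON string first, assuming the db column stores JSON text (the `ormar` model itself is not shown, and `to_jsonable` is a hypothetical helper):

```python
import json
from typing import Any


def to_jsonable(obj: Any) -> Any:
    """Fallback for json.dumps: serialize sets as sorted lists (hypothetical helper)."""
    if isinstance(obj, (set, frozenset)):
        return sorted(obj)
    # json.dumps expects the `default` callable to raise TypeError for unsupported types
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


covars = {
    "variable1": {"guess": 1, "min": 0, "max": 2},
    "variable3": {"guess": "red", "options": {"red", "blue", "green"}},
}

# json.dumps(covars) on its own raises TypeError because of the set in "options"
serialized = json.dumps(covars, default=to_jsonable, sort_keys=True)
restored = json.loads(serialized)  # "options" comes back as a sorted list
```

Sorting makes the stored string deterministic, since set iteration order is not guaranteed.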
Solution 1:[1]
First, since you don't seem to be using predefined keys, you could use a custom root type, which allows you to have arbitrary key names in a pydantic model, as discussed here. Next, you could use a `Union`, which allows a model attribute to accept different types. Thus, you can pass any number of entries of your three models, in any order. The order of the types inside the `Union` does matter, though: pydantic v1 (which the `__root__` syntax below requires) tries them from left to right and keeps the first one that validates.
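Since the `IntVariable` and `ContVariable` models have exactly the same attribute names, there is no way for pydantic to differentiate between the two: `IntVariable` is tried first, so `float` numbers passed to `guess`, `min` and `max` are converted to `int` (e.g. `12.2` becomes `12`). Giving the two models distinct field names avoids this, because an entry with `f_min`/`f_max` no longer validates against `IntVariable`. Note also that `min` and `max` are not reserved keywords in Python but names of built-in functions; they are legal field names, but renaming them, as shown below, is still preferable.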
```python
from typing import Dict, Set, Union

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class IntVariable(BaseModel):
    guess: int
    i_min: int
    i_max: int

class ContVariable(BaseModel):
    guess: float
    f_min: float
    f_max: float

class CatVariable(BaseModel):
    guess: str
    options: Set[str]

class Item(BaseModel):
    # custom root type (pydantic v1): the Union members are tried from left to right
    __root__: Union[IntVariable, ContVariable, CatVariable]

@app.post("/upload")
async def upload(covars: Dict[str, Item]):
    return covars
```
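The truncation described above, and why distinct field names avoid it, can be reproduced with a deliberately simplified standard-library stand-in for a left-to-right `Union`. Each candidate requires all of its fields and coerces them with `int()`/`float()`, which mirrors pydantic v1's lax (non-strict) coercion but is not pydantic's actual code; the dataclasses and the `lax_parse`/`parse_union` helpers are hypothetical:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Sequence


# Simplified stand-ins for the pydantic models; NOT pydantic itself.
@dataclass
class IntVariable:
    guess: int
    i_min: int
    i_max: int


@dataclass
class ContVariable:
    guess: float
    f_min: float
    f_max: float


@dataclass
class SameNameIntVariable:  # original field names, shared with the float model
    guess: int
    min: int
    max: int


@dataclass
class SameNameContVariable:
    guess: float
    min: float
    max: float


def lax_parse(model: type, data: Dict[str, Any]) -> Any:
    """Require every field and coerce it with its annotated type (int(12.2) == 12)."""
    kwargs = {}
    for f in fields(model):
        if f.name not in data:
            raise ValueError(f"{model.__name__}: missing field {f.name!r}")
        kwargs[f.name] = f.type(data[f.name])
    return model(**kwargs)


def parse_union(models: Sequence[type], data: Dict[str, Any]) -> Any:
    """Try each model from left to right and return the first that succeeds."""
    errors = []
    for model in models:
        try:
            return lax_parse(model, data)
        except (TypeError, ValueError) as exc:
            errors.append(str(exc))
    raise ValueError("; ".join(errors))


# Shared names: the int model is tried first and silently truncates the floats.
same_names = parse_union(
    [SameNameIntVariable, SameNameContVariable],
    {"guess": 12.2, "min": -3.4, "max": 30.8},
)

# Distinct names: the int model fails on the missing 'i_min', so the float model wins.
distinct_names = parse_union(
    [IntVariable, ContVariable],
    {"guess": 12.2, "f_min": -3.4, "f_max": 30.8},
)
```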
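Input sample below. JSON has no set type, so send `options` as a JSON array using square brackets `[]`; pydantic then converts the array to a `Set[str]`. Curly braces denote a JSON object, so a value like `{"red", "blue"}` is not valid JSON and FastAPI rejects the request.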
```json
{
    "variable1": {
        "guess": 1,
        "i_min": 0,
        "i_max": 2
    },
    "variable2": {
        "guess": "orange",
        "options": ["orange", "yellow", "brown"]
    },
    "variable3": {
        "guess": 12.2,
        "f_min": -3.4,
        "f_max": 30.8
    },
    "variable4": {
        "guess": "red",
        "options": ["red", "blue", "green"]
    },
    "variable5": {
        "guess": 2.15,
        "f_min": -1.75,
        "f_max": 11.8
    }
}
```
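A quick standard-library check of why `options` has to be sent with square brackets. The conversion to a set done here by hand only imitates what pydantic does for a `Set[str]` field:

```python
import json

# A JSON array (square brackets) parses to a Python list, which can then be
# turned into a set, much as pydantic does for a Set[str] field.
entry = json.loads('{"guess": "orange", "options": ["orange", "yellow", "brown"]}')
options = set(entry["options"])

# Curly braces around bare strings are not valid JSON (braces denote an object
# of key/value pairs), so such a body cannot even be parsed.
try:
    json.loads('{"guess": "orange", "options": {"orange", "yellow", "brown"}}')
    parsed_braces = True
except json.JSONDecodeError:
    parsed_braces = False
```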
Update
With the above, when an entry does not validate, the resulting `ValidationError` contains errors for all three models of the `Union`, instead of only for the model the entry was meant for. To avoid this, one could use Discriminated Unions (available since pydantic 1.9), as described in this answer. With Discriminated Unions, "only one explicit error is raised in case of failure". Example below:
app.py
```python
from typing import Dict, Literal, Set, Union

from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

class IntVariable(BaseModel):
    model_type: Literal['int']
    guess: int
    i_min: int
    i_max: int

class ContVariable(BaseModel):
    model_type: Literal['cont']
    guess: float
    f_min: float
    f_max: float

class CatVariable(BaseModel):
    model_type: Literal['cat']
    guess: str
    options: Set[str]

class Item(BaseModel):
    # 'model_type' selects the model, so only that model's errors are reported
    __root__: Union[IntVariable, ContVariable, CatVariable] = Field(..., discriminator='model_type')

@app.post("/upload")
async def upload(covars: Dict[str, Item]):
    return covars
```
Test data
```json
{
    "variable1": {
        "model_type": "int",
        "guess": 1,
        "i_min": 0,
        "i_max": 2
    },
    "variable2": {
        "model_type": "cat",
        "guess": "orange",
        "options": ["orange", "yellow", "brown"]
    },
    "variable3": {
        "model_type": "cont",
        "guess": 12.2,
        "f_min": -3.4,
        "f_max": 30.8
    },
    "variable4": {
        "model_type": "cat",
        "guess": "red",
        "options": ["red", "blue", "green"]
    },
    "variable5": {
        "model_type": "cont",
        "guess": 2.15,
        "f_min": -1.75,
        "f_max": 11.8
    }
}
```
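The difference in error reporting can be sketched without pydantic: the plain version tries every model and, when all of them fail, collects the errors of all of them, while the discriminated version reads `model_type` first and checks only the selected model. Only missing keys are checked here, and `REQUIRED`, `union_errors` and `discriminated_errors` are hypothetical names:

```python
from typing import Any, Dict, List

# Required keys per model: a simplified stand-in for the three pydantic models.
REQUIRED: Dict[str, List[str]] = {
    "int": ["guess", "i_min", "i_max"],
    "cont": ["guess", "f_min", "f_max"],
    "cat": ["guess", "options"],
}


def missing_keys(model: str, entry: Dict[str, Any]) -> List[str]:
    return [f"{model}: missing {key!r}" for key in REQUIRED[model] if key not in entry]


def union_errors(entry: Dict[str, Any]) -> List[str]:
    """Plain Union: if no model accepts the entry, report every model's errors."""
    all_errors: List[str] = []
    for model in REQUIRED:
        errors = missing_keys(model, entry)
        if not errors:
            return []  # the first model that validates wins
        all_errors.extend(errors)
    return all_errors


def discriminated_errors(entry: Dict[str, Any]) -> List[str]:
    """Discriminated union: pick the model from 'model_type', report only its errors."""
    model = entry.get("model_type")
    if model not in REQUIRED:
        return [f"invalid model_type {model!r}"]
    return missing_keys(model, entry)


bad_entry = {"model_type": "cont", "guess": 12.2, "f_min": -3.4}  # 'f_max' is missing
```

Here `union_errors(bad_entry)` yields four messages (two for `int`, one each for `cont` and `cat`), whereas `discriminated_errors(bad_entry)` yields only the one about `f_max`.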
An alternative solution would be a dependency function that iterates over the dictionary and tries to parse each entry with the three models inside a `try`/`except` block, similar to what is described in this answer (Update 1). However, that would require either looping through all the models, or having a discriminator in the entry (such as `"model_type"` above) that indicates which model to parse it with.
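The "loop through all the models" variant can be sketched with the standard library alone. This is only the core loop such a dependency would run, not a FastAPI dependency; the dataclasses stand in for the pydantic models, and `strict_parse`, `parse_covars` and the `ACCEPTED` table are hypothetical:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Union


@dataclass
class IntVariable:
    guess: int
    i_min: int
    i_max: int


@dataclass
class ContVariable:
    guess: float
    f_min: float
    f_max: float


@dataclass
class CatVariable:
    guess: str
    options: set


Variable = Union[IntVariable, ContVariable, CatVariable]
MODELS = (IntVariable, ContVariable, CatVariable)
# Accepted input types per annotated type (JSON arrays arrive as lists).
ACCEPTED = {int: (int,), float: (int, float), str: (str,), set: (list, set)}


def strict_parse(model: type, data: Dict[str, Any]) -> Any:
    """Build `model` from `data`, rejecting missing, extra or wrongly typed fields."""
    expected = {f.name: f.type for f in fields(model)}
    if set(data) != set(expected):
        raise ValueError(f"{model.__name__}: expected keys {sorted(expected)}")
    kwargs = {}
    for name, typ in expected.items():
        value = data[name]
        if type(value) not in ACCEPTED[typ]:
            raise TypeError(f"{model.__name__}.{name}: got {type(value).__name__}")
        kwargs[name] = typ(value)  # e.g. list -> set, int -> float
    return model(**kwargs)


def parse_covars(covars: Dict[str, Dict[str, Any]]) -> Dict[str, Variable]:
    """Try every model on every entry; raise with per-entry errors if any entry fails."""
    parsed: Dict[str, Variable] = {}
    errors: Dict[str, list] = {}
    for key, entry in covars.items():
        entry_errors = []
        for model in MODELS:
            try:
                parsed[key] = strict_parse(model, entry)
                break
            except (TypeError, ValueError) as exc:
                entry_errors.append(str(exc))
        else:  # no model accepted this entry
            errors[key] = entry_errors
    if errors:
        raise ValueError(errors)
    return parsed
```

Reporting errors per variable name keeps the messages readable, but, as noted above, every failing entry still produces one message per model.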
Solution 2:[2]
I ended up solving the problem using custom validators. Adding it here to complement the solution by @Chris.
I used a couple of other features to make this work. First, I set up the three types as an `Enum` to constrain the options. Second, I used `StrictInt`, `StrictFloat` and `StrictStr` to work around the fact that pydantic converts an `int` to a `float` when `float` is the first type listed for a field such as `guess`, i.e. if I were to use `guess: Union[float, int, str]`. Third, I hide the input field `vtype` (of type `VarType`) from the model's output via `Config.fields` and add another field, `type`, of type `str`, using a `root_validator`.
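The coercion problem the strict types avoid can be illustrated with two small standard-library stand-ins: `lax_union` converts with the first type that accepts the value (roughly what a non-strict `Union[float, int, str]` does), while `strict_union` only accepts exact type matches, similar in spirit to `Union[StrictFloat, StrictInt, StrictStr]`. Both helpers are hypothetical simplifications, not pydantic's implementation:

```python
from typing import Any, Sequence


def lax_union(value: Any, types: Sequence[type]) -> Any:
    """Return the value converted by the first type whose constructor accepts it."""
    for t in types:
        try:
            return t(value)
        except (TypeError, ValueError):
            continue
    raise TypeError(f"no type in {types} accepts {value!r}")


def strict_union(value: Any, types: Sequence[type]) -> Any:
    """Return the value unchanged if its exact type is one of `types`."""
    if type(value) in types:
        return value
    raise TypeError(f"{type(value).__name__} is not one of {[t.__name__ for t in types]}")


lax = lax_union(1, (float, int, str))        # the int is turned into a float
strict = strict_union(1, (float, int, str))  # the int is kept as an int
```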
```python
from enum import Enum
from typing import Optional, Set, Union

import pydantic
from pydantic import StrictFloat, StrictInt, StrictStr, root_validator, validator


class VarType(Enum):
    int = "int"
    cont = "cont"
    cat = "cat"


class Variable(pydantic.BaseModel):
    vtype: VarType
    guess: Union[StrictFloat, StrictInt, StrictStr]
    min: Optional[Union[StrictFloat, StrictInt]] = None
    max: Optional[Union[StrictFloat, StrictInt]] = None
    options: Optional[Set[str]] = None

    # this check is needed to make 'type' available for 'check_guess' validator, but it is not otherwise needed since
    # VarType itself ensures type validation
    @validator('vtype', allow_reuse=True)
    def req_check(cls, t):
        assert t.value in ['int', 'cont', 'cat'], "'vtype' must take value from set ['int', 'cont', 'cat']"
        return t

    # add new field called "type"; skip_on_failure avoids a KeyError when 'vtype' itself is invalid
    @root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
    def insert_type(cls, values):
        if values['vtype'].value == 'int':
            values['type'] = 'int'
        elif values['vtype'].value == 'cont':
            values['type'] = 'float'
        elif values['vtype'].value == 'cat':
            values['type'] = 'str'
        return values

    # runs on the raw input (so 'vtype' is still a plain string) and casts guess/min/max to match it;
    # missing keys are left alone so that the 'min'/'max' validators below can report them
    @root_validator(pre=True, allow_reuse=True)
    def set_guessminmax_types(cls, values):
        vtype = values.get('vtype')
        if vtype in ('int', 'cont'):
            cast = int if vtype == 'int' else float
            for key in ('guess', 'min', 'max'):
                if values.get(key) is not None:
                    values[key] = cast(values[key])
        return values

    # check right data type of 'guess'
    @validator('guess', allow_reuse=True)
    def check_guess_datatype(cls, g, values):
        if 'vtype' not in values:  # 'vtype' failed validation and is reported already
            return g
        expected = {'int': int, 'cont': float, 'cat': str}[values['vtype'].value]
        assert isinstance(g, expected), (
            "data type mismatch between 'guess' and 'vtype'. "
            f"Expected type '{expected.__name__}' from 'guess' but received {type(g)}"
        )
        return g

    # check that 'min' is included for types 'int', 'cont' (always=True: also runs when 'min' is omitted)
    @validator('min', always=True, allow_reuse=True)
    def check_min_included(cls, m, values):
        if 'vtype' in values and values['vtype'].value in ['int', 'cont']:
            assert m is not None, "'min' is required for vtype 'int' and 'cont'"
        return m

    # check that 'max' is included for types 'int', 'cont'
    @validator('max', always=True, allow_reuse=True)
    def check_max_included(cls, m, values):
        if 'vtype' in values and values['vtype'].value in ['int', 'cont']:
            assert m is not None, "'max' is required for vtype 'int' and 'cont'"
        return m

    # check that 'options' is included for type 'cat'
    @validator('options', always=True, allow_reuse=True)
    def check_options_included(cls, op, values):
        if 'vtype' in values and values['vtype'].value == 'cat':
            assert op is not None, "'options' is required for vtype 'cat'"
        return op

    # removes all fields which have value None
    @root_validator(pre=False, allow_reuse=True)
    def remove_all_nones(cls, values):
        values = {k: v for k, v in values.items() if v is not None}
        return values

    class Config:
        fields = {"vtype": {"exclude": True}}
```
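For comparison, the constraints enforced by the separate validators above (required fields per `vtype`, the type of `guess`, the derived `type` field, dropping `None` values) can also be written as lookup tables. A minimal standard-library sketch; the tables and `check_variable` are hypothetical, and unlike the model it does not pre-convert `guess`/`min`/`max`:

```python
from typing import Any, Dict

# Hypothetical lookup tables summarising the rules of the `Variable` model.
REQUIRED = {
    "int": {"guess", "min", "max"},
    "cont": {"guess", "min", "max"},
    "cat": {"guess", "options"},
}
GUESS_TYPE = {"int": int, "cont": float, "cat": str}
TYPE_NAME = {"int": "int", "cont": "float", "cat": "str"}


def check_variable(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Validate one entry and return it with 'vtype' replaced by 'type'."""
    vtype = entry.get("vtype")
    if vtype not in REQUIRED:
        raise ValueError(f"'vtype' must be one of {sorted(REQUIRED)}, got {vtype!r}")
    present = {k for k, v in entry.items() if v is not None}
    missing = REQUIRED[vtype] - present
    if missing:
        raise ValueError(f"missing {sorted(missing)} for vtype {vtype!r}")
    guess = entry["guess"]
    if type(guess) is not GUESS_TYPE[vtype]:
        raise TypeError(
            f"expected {GUESS_TYPE[vtype].__name__} for 'guess', got {type(guess).__name__}"
        )
    # drop 'vtype' and None values, then add the derived 'type' field
    cleaned = {k: v for k, v in entry.items() if k != "vtype" and v is not None}
    cleaned["type"] = TYPE_NAME[vtype]
    return cleaned
```

Keeping the rules as data makes adding a fourth variable type a matter of extending the tables rather than touching several validators.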
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | |
Solution 2 | svedel |