pydantic decimal 类型 自动根据 decimal_places 对小数点四舍五入(保留指定长度的小数点)
发布时间:
问题
在使用 pydantic 限制字段类型的时候, 如果某字段是 decimal ,并且我们限制 小数点 decimal_places 为某个值,假设为 2
price: Decimal = Field(max_digits=50, decimal_places=2)
但是上游传递进来的小数 可能并不符合模型的要求, 比如 上游传递进来的可能是 123.888888 55.999999 小数点超过了 2, 那么 pydantic 就会报错
解决方案
方案1:
使用 pydantic 的 model_validator,但是这个方法比较麻烦, 需要写不少代码
方案2:
自己封装一个,读取 元数据里的 decimal_places, 动态进行四舍五入 处理小数点,代码如下:
from decimal import Decimal, ROUND_HALF_UP
from pydantic import BaseModel, model_validator, Field
from typing import Optional, Any, Union
from pydantic.fields import FieldInfo
import types
import decimal
class AutoDecimalModel(BaseModel):
"""自动根据Field配置处理Decimal精度的基类"""
@model_validator(mode='before')
@classmethod
def round_decimal_fields(cls, data: Any) -> Any:
"""只处理Decimal字段的精度"""
if not isinstance(data, dict):
return data
result = data.copy()
for field_name, field_info in cls.model_fields.items():
# 跳过非Decimal字段
if not cls._is_decimal_field(field_info):
continue
# 只处理Decimal字段
decimal_places = cls._get_decimal_places_from_field(field_info)
if decimal_places is not None and field_name in result:
result[field_name] = cls._process_decimal_value(
result[field_name], decimal_places
)
return result
@classmethod
def _is_decimal_field(cls, field_info: FieldInfo) -> bool:
"""快速判断是否是Decimal字段 - 支持 Decimal | None 语法"""
annotation = field_info.annotation
# 调试信息
# print(f"检查字段类型: {annotation}, 类型: {type(annotation)}")
# 1. 直接是 Decimal 类型
if annotation is Decimal:
return True
# 2. 是 Optional[Decimal] 类型
if (hasattr(annotation, '__origin__') and
annotation.__origin__ is Optional):
args = getattr(annotation, '__args__', [])
if Decimal in args:
return True
# 3. 是 Union[Decimal, None] 类型
if (hasattr(annotation, '__origin__') and
annotation.__origin__ is Union):
args = getattr(annotation, '__args__', [])
if Decimal in args:
return True
# 4. 处理 Python 3.10+ 的 | 语法 (types.UnionType)
if hasattr(types, 'UnionType') and isinstance(annotation, types.UnionType):
args = getattr(annotation, '__args__', [])
if Decimal in args:
return True
# 5. 处理其他可能的 Union 语法变体
if (hasattr(annotation, '__args__') and
any(arg is Decimal for arg in getattr(annotation, '__args__', []))):
return True
return False
@classmethod
def _get_decimal_places_from_field(cls, field_info: FieldInfo) -> Optional[int]:
"""从Field信息中提取decimal_places配置"""
if hasattr(field_info, 'json_schema_extra'):
extra = field_info.json_schema_extra
if extra and 'decimal_places' in extra:
return extra['decimal_places']
if hasattr(field_info, 'metadata') and field_info.metadata:
for metadata in field_info.metadata:
if hasattr(metadata, 'decimal_places'):
return metadata.decimal_places
elif isinstance(metadata, dict) and 'decimal_places' in metadata:
return metadata['decimal_places']
return None
@classmethod
def _process_decimal_value(cls, value: Any, decimal_places: int) -> Any:
"""处理Decimal值的四舍五入"""
if value is None:
return None
try:
decimal_val = Decimal(str(value))
quantize_str = f"0.{'0' * decimal_places}" if decimal_places > 0 else "1"
return decimal_val.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
except (ValueError, TypeError, decimal.InvalidOperation):
return value
测试效果
# 测试
class TestModel(AutoDecimalModel):
price: Decimal = Field(max_digits=50, decimal_places=3)
# 运行测试
print("=== 开始测试 ===")
test_cases = ["12.3424242425", "113.4511111111"]
for case in test_cases:
try:
result = TestModel(price=case)
print(f"输入: {case} -> 输出: {result.price}")
except Exception as e:
print(f"输入: {case} -> 错误: {e}")输出效果
=== 开始测试 ===
输入: 12.3424242425 -> 输出: 12.342
输入: 113.4511111111 -> 输出: 113.451工作原理
模型验证前拦截: 使用 @model_validator(mode=‘before’) 在数据验证前进行处理
字段配置读取: 通过 cls.model_fields 获取字段的 Field 配置信息
精度自动应用: 读取 decimal_places 配置并应用到对应的 Decimal 值
四舍五入处理: 使用 Decimal.quantize() 进行精确的四舍五入
其他
希望官方早点支持这样的功能