pydantic decimal 类型 自动根据 decimal_places 对小数点四舍五入(保留指定长度的小数点)

发布时间:

问题

在使用 pydantic 限制字段类型的时候, 如果某字段是 decimal ,并且我们限制 小数点 decimal_places 为某个值,假设为 2

price: Decimal = Field(max_digits=50, decimal_places=2)

但是上游传递进来的小数 可能并不符合模型的要求, 比如 上游传递进来的可能是 123.888888 55.999999 小数点超过了 2, 那么 pydantic 就会报错

解决方案

方案1:

使用 pydantic 的 model_validator,但是这个方法比较麻烦, 需要写不少代码

方案2:

自己封装一个,读取 元数据里的 decimal_places, 动态进行四舍五入 处理小数点,代码如下:

from decimal import Decimal, ROUND_HALF_UP from pydantic import BaseModel, model_validator, Field from typing import Optional, Any, Union from pydantic.fields import FieldInfo import types import decimal class AutoDecimalModel(BaseModel): """自动根据Field配置处理Decimal精度的基类""" @model_validator(mode='before') @classmethod def round_decimal_fields(cls, data: Any) -> Any: """只处理Decimal字段的精度""" if not isinstance(data, dict): return data result = data.copy() for field_name, field_info in cls.model_fields.items(): # 跳过非Decimal字段 if not cls._is_decimal_field(field_info): continue # 只处理Decimal字段 decimal_places = cls._get_decimal_places_from_field(field_info) if decimal_places is not None and field_name in result: result[field_name] = cls._process_decimal_value( result[field_name], decimal_places ) return result @classmethod def _is_decimal_field(cls, field_info: FieldInfo) -> bool: """快速判断是否是Decimal字段 - 支持 Decimal | None 语法""" annotation = field_info.annotation # 调试信息 # print(f"检查字段类型: {annotation}, 类型: {type(annotation)}") # 1. 直接是 Decimal 类型 if annotation is Decimal: return True # 2. 是 Optional[Decimal] 类型 if (hasattr(annotation, '__origin__') and annotation.__origin__ is Optional): args = getattr(annotation, '__args__', []) if Decimal in args: return True # 3. 是 Union[Decimal, None] 类型 if (hasattr(annotation, '__origin__') and annotation.__origin__ is Union): args = getattr(annotation, '__args__', []) if Decimal in args: return True # 4. 处理 Python 3.10+ 的 | 语法 (types.UnionType) if hasattr(types, 'UnionType') and isinstance(annotation, types.UnionType): args = getattr(annotation, '__args__', []) if Decimal in args: return True # 5. 处理其他可能的 Union 语法变体 if (hasattr(annotation, '__args__') and any(arg is Decimal for arg in getattr(annotation, '__args__', []))): return True return False @classmethod def _get_decimal_places_from_field(cls, field_info: FieldInfo) -> Optional[int]: """从Field信息中提取decimal_places配置""" if hasattr(field_info, 'json_schema_extra'): extra = field_info.json_schema_extra if extra and 'decimal_places' in extra: return extra['decimal_places'] if hasattr(field_info, 'metadata') and field_info.metadata: for metadata in field_info.metadata: if hasattr(metadata, 'decimal_places'): return metadata.decimal_places elif isinstance(metadata, dict) and 'decimal_places' in metadata: return metadata['decimal_places'] return None @classmethod def _process_decimal_value(cls, value: Any, decimal_places: int) -> Any: """处理Decimal值的四舍五入""" if value is None: return None try: decimal_val = Decimal(str(value)) quantize_str = f"0.{'0' * decimal_places}" if decimal_places > 0 else "1" return decimal_val.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP) except (ValueError, TypeError, decimal.InvalidOperation): return value

测试效果

# 测试 class TestModel(AutoDecimalModel): price: Decimal = Field(max_digits=50, decimal_places=3) # 运行测试 print("=== 开始测试 ===") test_cases = ["12.3424242425", "113.4511111111"] for case in test_cases: try: result = TestModel(price=case) print(f"输入: {case} -> 输出: {result.price}") except Exception as e: print(f"输入: {case} -> 错误: {e}")

输出效果

=== 开始测试 === 输入: 12.3424242425 -> 输出: 12.342 输入: 113.4511111111 -> 输出: 113.451

工作原理

模型验证前拦截: 使用 @model_validator(mode=‘before’) 在数据验证前进行处理

字段配置读取: 通过 cls.model_fields 获取字段的 Field 配置信息

精度自动应用: 读取 decimal_places 配置并应用到对应的 Decimal 值

四舍五入处理: 使用 Decimal.quantize() 进行精确的四舍五入

其他

希望官方早点支持这样的功能


2025 © 糊涂.