|
| 1 | +import operator |
| 2 | + |
1 | 3 | from datetime import timedelta
|
2 | 4 | from mongoengine.base import BaseField
|
3 |
| -from mongoengine.fields import IntField, StringField, EmailField |
| 5 | +from mongoengine.fields import EmailField, IntField, ListField, StringField |
| 6 | + |
| 7 | +try: |
| 8 | + from functools import reduce |
| 9 | +except ImportError: |
| 10 | + # reduce is a builtin in Python2 |
| 11 | + pass |
4 | 12 |
|
5 | 13 |
|
6 | 14 | class TimedeltaField(BaseField):
|
@@ -105,6 +113,33 @@ class IntEnumField(EnumField, IntField):
|
105 | 113 | pass
|
106 | 114 |
|
107 | 115 |
|
| 116 | +class IntFlagField(ListField): |
| 117 | + def __init__(self, enum, **kwargs): |
| 118 | + super(IntFlagField, self).__init__(IntEnumField(enum), **kwargs) |
| 119 | + |
| 120 | + def __get__(self, instance, owner): |
| 121 | + if instance is None: |
| 122 | + return self |
| 123 | + |
| 124 | + return self.field.enum(reduce( |
| 125 | + operator.or_, instance._data.get(self.name, []), 0)) |
| 126 | + |
| 127 | + def __set__(self, instance, value): |
| 128 | + # copy mongoengine |
| 129 | + if value is None: |
| 130 | + if self.null: |
| 131 | + value = None |
| 132 | + elif self.default is not None: |
| 133 | + value = self.default |
| 134 | + if callable(value): |
| 135 | + value = value() |
| 136 | + |
| 137 | + if value is not None and not isinstance(value, list): |
| 138 | + value = [i for i in self.field.enum if i and i & value == i] |
| 139 | + |
| 140 | + super(IntFlagField, self).__set__(instance, value) |
| 141 | + |
| 142 | + |
108 | 143 | class StringEnumField(EnumField, StringField):
|
109 | 144 | """A variation on :class:`EnumField` for only string containing enumeration.
|
110 | 145 | """
|
|
0 commit comments