-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathpolicy.py
More file actions
148 lines (125 loc) · 4.57 KB
/
policy.py
File metadata and controls
148 lines (125 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import annotations
import random
from copy import deepcopy
from typing import Optional
import torch
from torch import nn, Tensor
from torch.distributions import Categorical
from dda.operations import *
class SubPolicyStage(nn.Module):
    """A single stage of a sub-policy: a learnable mixture over operations.

    During training, every candidate operation is applied and the results are
    blended with temperature-scaled softmax weights (differentiable). At
    inference, one operation is sampled from the same categorical
    distribution and applied alone.
    """

    def __init__(self,
                 operations: nn.ModuleList,
                 temperature: float,
                 ):
        super(SubPolicyStage, self).__init__()
        self.operations = operations
        # One raw logit per operation; converted to probabilities by the
        # `weights` property before use.
        self._weights = nn.Parameter(torch.ones(len(self.operations)))
        self.temperature = temperature

    def forward(self,
                input: Tensor
                ) -> Tensor:
        if not self.training:
            # Inference: draw a single operation according to the learned
            # mixture weights.
            chosen = Categorical(self.weights).sample()
            return self.operations[chosen](input)
        # Training: soft (differentiable) mixture of all operation outputs.
        stacked = torch.stack([op(input) for op in self.operations])
        mix = self.weights.view(-1, 1, 1, 1, 1)
        return (stacked * mix).sum(0)

    @property
    def weights(self
                ):
        # Temperature-scaled softmax over the raw logits; lower temperature
        # sharpens the distribution.
        return self._weights.div(self.temperature).softmax(0)
class SubPolicy(nn.Module):
def __init__(self,
sub_policy_stage: SubPolicyStage,
operation_count: int,
):
super(SubPolicy, self).__init__()
self.stages = nn.ModuleList([deepcopy(sub_policy_stage) for _ in range(operation_count)])
def forward(self,
input: Tensor
) -> Tensor:
for stage in self.stages:
input = stage(input)
return input
class Policy(nn.Module):
    """A full augmentation policy: a pool of sub-policies, one of which is
    chosen uniformly at random for each forward pass (per chunk of the batch).

    If per-channel `mean`/`std` are given, the augmented output is normalized
    in-place from [0, 1] to (roughly) [-1, 1]; the statistics are registered
    as buffers so they follow `.to(device)` and are saved in `state_dict`.
    """

    def __init__(self,
                 operations: nn.ModuleList,
                 num_sub_policies: int,
                 temperature: float = 0.05,
                 operation_count: int = 2,
                 num_chunks: int = 4,
                 mean: Optional[Tensor] = None,
                 std: Optional[Tensor] = None,
                 ):
        super(Policy, self).__init__()
        self.sub_policies = nn.ModuleList([SubPolicy(SubPolicyStage(operations, temperature), operation_count)
                                           for _ in range(num_sub_policies)])
        self.num_sub_policies = num_sub_policies
        self.temperature = temperature
        self.operation_count = operation_count
        self.num_chunks = num_chunks
        if mean is None:
            # No normalization requested; keep plain `None` attributes so the
            # `self._mean is None` check in `forward` still works.
            self._mean, self._std = None, None
        else:
            self.register_buffer('_mean', mean)
            self.register_buffer('_std', std)
        # Randomize every learnable parameter (mixture logits and the
        # operations' own magnitude/probability parameters) uniformly in [0, 1).
        for p in self.parameters():
            nn.init.uniform_(p, 0, 1)

    def forward(self,
                input: Tensor
                ) -> Tensor:
        # Apply an independently sampled sub-policy to each chunk of the
        # batch, then optionally normalize: [0, 1] -> [-1, 1].
        if self.num_chunks > 1:
            out = [self._forward(inp) for inp in input.chunk(self.num_chunks)]
            x = torch.cat(out, dim=0)
        else:
            x = self._forward(input)
        if self._mean is None:
            return x
        else:
            return self.normalize_(x)

    def _forward(self,
                 input: Tensor
                 ) -> Tensor:
        # Uniformly sample one sub-policy and apply it.
        index = random.randrange(self.num_sub_policies)
        return self.sub_policies[index](input)

    def normalize_(self,
                   input: Tensor
                   ) -> Tensor:
        # In-place normalization: [0, 1] -> [-1, 1] (exact for mean=std=0.5).
        return input.add_(- self._mean[:, None, None]).div_(self._std[:, None, None])

    def denormalize_(self,
                     input: Tensor
                     ) -> Tensor:
        # In-place inverse of `normalize_`: [-1, 1] -> [0, 1].
        return input.mul_(self._std[:, None, None]).add_(self._mean[:, None, None])

    @staticmethod
    def dda_operations():
        """Default operation pool from `dda.operations`.

        Bug fix: the original list contained ``TranslateY()`` twice and no
        ``TranslateX()``, so horizontal translation was never searched and
        vertical translation was double-weighted. Both axes are now covered.
        """
        return [
            ShearX(),
            ShearY(),
            TranslateX(),  # was a duplicated TranslateY()
            TranslateY(),
            Rotate(),
            HorizontalFlip(),
            Invert(),
            Solarize(),
            Posterize(),
            Contrast(),
            Saturate(),
            Brightness(),
            Sharpness(),
            AutoContrast(),
            Equalize(),
        ]

    @staticmethod
    def faster_auto_augment_policy(num_sub_policies: int,
                                   temperature: float,
                                   operation_count: int,
                                   num_chunks: int,
                                   mean: Optional[torch.Tensor] = None,
                                   std: Optional[torch.Tensor] = None,
                                   ) -> Policy:
        """Build a Faster-AutoAugment policy over the default operation pool.

        When `mean`/`std` are omitted, defaults of 0.5 per channel map
        [0, 1] inputs exactly to [-1, 1].
        """
        if mean is None or std is None:
            mean = torch.ones(3) * 0.5
            std = torch.ones(3) * 0.5
        return Policy(nn.ModuleList(Policy.dda_operations()), num_sub_policies, temperature, operation_count,
                      num_chunks, mean=mean, std=std)