This repository has been archived by the owner on Nov 30, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
94 lines (82 loc) · 3.42 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#
# Deep Edit
# Copyright (c) 2020 Homedeck, LLC.
#
from plasma.filters import highlights, selective_color, shadows, temperature, tint, tone_curve
from torch import cat, tensor, zeros, zeros_like, Tensor
from torch.nn import Linear, Module, ReLU, Sequential, Tanh
from torch.nn.functional import interpolate
from torchvision.models import resnet34
from torchsummary import summary
class DeepEdit(Module):
    """
    Predict photo-editing coefficients with a ResNet-34 regressor and apply them
    through a fixed chain of ``plasma`` filters.

    The regressor head emits 10 coefficients per image in (-1, 1) (Tanh output),
    consumed by ``filter`` in this order:
        [0:2]   tone curve interior control points
        [2:3]   temperature
        [3:4]   tint
        [4:10]  selective color, reshaped to (3, 2): one row per basis color
    """

    def __init__(self):
        super().__init__()
        # Model: pretrained ResNet-34 whose classifier is replaced by an MLP regressor
        # producing the 10 editing coefficients; Tanh bounds each one to (-1, 1).
        self.model = resnet34(pretrained=True, progress=True)
        in_features = self.model.fc.in_features
        self.model.fc = Sequential(
            Linear(in_features, 1024),
            ReLU(inplace=True),
            Linear(1024, 256),
            ReLU(inplace=True),
            Linear(256, 64),
            ReLU(inplace=True),
            Linear(64, 10),
            Tanh(),
        )
        # Constant buffers. Names are part of the state_dict, so keep them stable.
        self.register_buffer("x_s", tensor(0.8))   # fixed amount passed to `shadows`
        self.register_buffer("x_h", tensor(-0.9))  # fixed amount passed to `highlights`
        self.register_buffer("t_0", tensor(-1.))   # tone curve dark control point (first)
        self.register_buffer("t_3", tensor(1.))    # tone curve white control point (last)
        # Third selective-color adjustment per basis color, held at zero (not predicted).
        self.register_buffer("selective_lum", zeros(1, 3, 1))
        # Selective-color basis colors, one RGB row each.
        self.register_buffer("basis", tensor([
            [1.0, 0.65, 0.0],   # orange
            [1.0, 1.0, 0.0],    # yellow
            [0.0, 1.0, 0.0],    # green
        ]))

    def forward(self, input: Tensor) -> Tensor:
        """
        Predict editing coefficients for an image batch and apply them.

        Parameters:
            input (Tensor): Input image with shape (N,3,H,W) in range [-1., 1.].

        Returns:
            Tensor: Filtered image with shape (N,3,H,W) in range [-1., 1.].
        """
        weights = self.weights(input)
        return self.filter(input, weights)

    def weights(self, input: Tensor) -> Tensor:
        """
        Compute the editing coefficients for a given image.

        Parameters:
            input (Tensor): Input image with shape (N,3,H,W) in range [-1., 1.].

        Returns:
            Tensor: Editing coefficients with shape (N,10) in range [-1., 1.].
        """
        # The regressor always sees a 512x512 resample, regardless of the input resolution.
        resized = interpolate(input, size=(512, 512), mode="bilinear", align_corners=False)
        return self.model(resized)

    def filter(self, input: Tensor, weights: Tensor) -> Tensor:
        """
        Apply editing forward model.

        Parameters:
            input (Tensor): Input image with shape (N,3,H,W) in range [-1., 1.].
            weights (Tensor): Editing coefficients with shape (N,10) in [-1., 1.],
                laid out as described in the class docstring.

        Returns:
            Tensor: Filtered image with shape (N,3,H,W) in range [-1., 1.].
        """
        batch, _, _, _ = input.shape
        tone_weights = weights[:, 0:2]
        chroma_weights = weights[:, 2:4]
        selective_weights = weights[:, 4:]
        # Fixed adjustments, identical for every image
        input = shadows(input, self.x_s)
        input = highlights(input, self.x_h)
        # Tone: fixed dark/white control points around the 2 predicted ones -> (N,4)
        controls = cat([self.t_0.expand(batch, 1), tone_weights, self.t_3.expand(batch, 1)], dim=1)
        input = tone_curve(input, controls)
        # Chromaticity
        x_temp, x_tint = chroma_weights.split(1, dim=1)  # each (N,1)
        input = temperature(input, x_temp)
        input = tint(input, x_tint)
        # Selective color: 2 predicted adjustments per basis color plus the fixed third one.
        # reshape (not view) so this works for any stride layout of `weights`.
        x_selective = selective_weights.reshape(-1, 3, 2)             # (N,3,2)
        x_selective_lum = self.selective_lum.expand(batch, -1, -1)    # (N,3,1)
        x_selective = cat([x_selective, x_selective_lum], dim=2)      # (N,3,3)
        input = selective_color(input, self.basis, x_selective)
        return input
if __name__ == "__main__":
    from torch import cuda

    # torchsummary.summary defaults to device="cuda" and builds CUDA dummy inputs
    # whenever CUDA is available, so the model must live on that same device.
    device = "cuda" if cuda.is_available() else "cpu"
    model = DeepEdit().to(device)
    summary(model, (3, 1024, 1024), batch_size=8, device=device)