-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataset_mean_std_ycrcb.py
47 lines (33 loc) · 1.2 KB
/
dataset_mean_std_ycrcb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import cv2
from pytorch_toolbelt.utils import fs
from tqdm import tqdm
from alaska2 import idct8
from alaska2.dataset import idct8v2
def compute_mean_std(dataset):
    """
    Estimate per-channel mean and standard deviation over a set of images.

    For every file name in ``dataset`` the matching ``.npz`` file (same path,
    extension swapped) is loaded, its ``dct_y`` / ``dct_cb`` / ``dct_cr``
    entries are passed through ``idct8v2``, and the mean and variance of each
    resulting plane are accumulated.

    The returned std is the square root of the *average per-image variance*
    (https://stats.stackexchange.com/questions/25848/how-to-sum-a-standard-deviation).
    The spread of the per-image means around the global mean is not included.

    Args:
        dataset: Iterable of image file names. It is consumed exactly once, so
            a generator or a ``tqdm`` wrapper works as well as a list.

    Returns:
        Tuple ``(mean, std)`` of two float64 arrays of length 3, in the order
        (Y, Cb, Cr). Both arrays are all-NaN when ``dataset`` yields nothing.
    """
    mean_sum = np.zeros(3, dtype=np.float64)
    var_sum = np.zeros(3, dtype=np.float64)
    n_items = 0

    for image_fname in dataset:
        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the file open; the context manager releases the handle right away
        # instead of leaving it to garbage collection.
        with np.load(fs.change_extension(image_fname, ".npz")) as dct_file:
            # NOTE(review): an earlier comment here said a normalization puts
            # values at roughly zero mean / unit variance. None is visible in
            # this function - presumably it lives inside idct8v2; confirm.
            planes = (
                idct8v2(dct_file["dct_y"]),
                idct8v2(dct_file["dct_cb"]),
                idct8v2(dct_file["dct_cr"]),
            )

        for channel, plane in enumerate(planes):
            mean_sum[channel] += plane.mean()
            var_sum[channel] += plane.std() ** 2
        n_items += 1

    if n_items == 0:
        # Nothing to average: return NaN explicitly rather than relying on a
        # 0/0 division that also emits a NumPy RuntimeWarning.
        return np.full(3, np.nan, dtype=np.float64), np.full(3, np.nan, dtype=np.float64)

    return mean_sum / n_items, np.sqrt(var_sum / n_items)
def main():
    """
    Print channel statistics for a sample of the ALASKA2 cover images.

    Lists the images found under the hard-coded ``Cover`` directory, keeps the
    first 500 entries, and prints the result of ``compute_mean_std`` on them
    (with a ``tqdm`` progress bar) after the label ``YCbCr``.
    """
    cover_images = fs.find_images_in_dir("/home/bloodaxe/datasets/ALASKA2/Cover")
    # Only a 500-image prefix is used rather than the full directory listing.
    sample = cover_images[:500]
    stats = compute_mean_std(tqdm(sample))
    print("YCbCr", stats)
# Run the statistics report only when this file is executed as a script,
# so importing the module does not trigger the dataset scan.
if __name__ == "__main__":
    main()