Skip to content

Commit

Permalink
Explore tile helper class
Browse files Browse the repository at this point in the history
  • Loading branch information
hunhoffe committed Oct 19, 2024
1 parent 00713ea commit 2650c41
Showing 1 changed file with 100 additions and 25 deletions.
125 changes: 100 additions & 25 deletions python/helpers/tensortiler/tensortiler2d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
import numpy as np
import matplotlib as plt


class TensorTile2DIter:
def __init__(self, sizes, strides, offset_fn, num_steps):
self.__num_steps = num_steps
self.__current_step = 0

self.__sizes = sizes
self.__strides = strides
self.__offset_fn = offset_fn

def __iter__(self):
return self

def __next__(self):
if self.__current_step == self.__num_steps:
raise StopIteration
step = self.__current_step
self.__current_step += 1
return (self.__offset_fn(step), self.__sizes, self.__strides)


class TensorTiler2D:
Expand All @@ -9,8 +28,8 @@ def __init__(
tensor_width: int,
tile_height: int,
tile_width: int,
tensor_row_major=True,
tile_row_major=True,
tensor_col_major=False,
tile_col_major=False,
):
assert tensor_height % tile_height == 0
assert tensor_width % tile_width == 0
Expand All @@ -37,34 +56,90 @@ def __init__(
1,
]

self.__tensor_col_major = tensor_col_major
if tensor_col_major:
self.__sizes[0], self.__sizes[1] = self.__sizes[1], self.__sizes[0]
self.__strides[0], self.__strides[1] = self.__strides[1], self.__strides[0]
if tile_col_major:
self.__sizes[2], self.__sizes[3] = self.__sizes[3], self.__sizes[2]
self.__strides[2], self.__strides[3] = self.__strides[3], self.__strides[2]

def tile_iter(self, chunk_height: int = 1, chunk_width: int = 1, col_major=False):
assert self.__num_tiles_per_row % chunk_width == 0
assert self.__num_tiles_per_column % chunk_height == 0

chunks_per_row = self.__num_tiles_per_row // chunk_width
chunks_per_column = self.__num_tiles_per_column // chunk_height

steps = chunks_per_row * chunks_per_column

def calc_offset(iter_num):
if not col_major:
row_idx = iter_num % chunks_per_row
col_idx = iter_num // chunks_per_row
else:
col_idx = iter_num % chunks_per_column
row_idx = iter_num // chunks_per_column

offset = row_idx * chunk_width * self.__tile_width
offset += col_idx * chunk_height * self.__tensor_width * self.__tile_height
return offset

iter_sizes = self.__sizes.copy()

size_idx = [0, 1]
if self.__tensor_col_major:
size_idx = [1, 0]
iter_sizes[size_idx[0]] = chunk_height
iter_sizes[size_idx[1]] = chunk_width
# TODO: handle row/col major

return TensorTile2DIter(
iter_sizes, self.__strides, offset_fn=calc_offset, num_steps=steps
)

def __str__(self):
return f"sizes={self.__sizes}, strides={self.__strides}"

def access_heatmap(self):
access_order_tensor = np.zeros(
self.__tensor_height * self.__tensor_width, dtype=np.int32
)
@classmethod
def access_heatmap(
cls,
tensor_height: int,
tensor_width: int,
tile_height: int,
tile_width: int,
sizes: list[int],
strides: list[int],
offset: int = 0,
show_tile: bool = True,
):
try:
import matplotlib.pyplot as plt
except:
raise ImportError(
"You must pip install matplotlib in order to use the access_heatmap() method"
)

access_order_tensor = np.zeros(tensor_height * tensor_width, dtype=np.int32)
access_count = 0
for i in range(self.__sizes[0]):
for j in range(self.__sizes[1]):
for k in range(self.__sizes[2]):
for l in range(self.__sizes[3]):
for i in range(sizes[0]):
for j in range(sizes[1]):
for k in range(sizes[2]):
for l in range(sizes[3]):
access_order_tensor[
i * self.__strides[0]
+ j * self.__strides[1]
+ k * self.__strides[2]
+ l * self.__strides[3]
offset
+ i * strides[0]
+ j * strides[1]
+ k * strides[2]
+ l * strides[3]
] = access_count
access_count += 1
access_order_tensor = access_order_tensor.reshape(
(self.__tensor_height, self.__tensor_width)
)
print("Per Tile:")
per_tile_access_order = access_order_tensor[
0 : self.__tile_height, 0 : self.__tile_width
]
plt.imshow(per_tile_access_order, cmap="hot", interpolation="nearest")
plt.show()
access_order_tensor = access_order_tensor.reshape((tensor_height, tensor_width))
if show_tile:
print("Per Tile:")
per_tile_access_order = access_order_tensor[0:tile_height, 0:tile_width]
plt.imshow(per_tile_access_order, cmap="gist_heat", interpolation="nearest")
plt.show()
print("Per Tensor:")
plt.imshow(access_order_tensor, cmap="hot", interpolation="nearest")
plt.imshow(access_order_tensor, cmap="turbo", interpolation="nearest")
plt.show()

0 comments on commit 2650c41

Please sign in to comment.