For completeness this is a text summary of what I am trying to do:
1. Split the image into tiles.
2. Run each tile through a new copy of the model for a set number of iterations.
3. Feather the tiles and put them into rows.
4. Feather the rows and put them back together into the original image/tensor.
5. Maybe save the output, then split the output into tiles again.
6. Repeat steps 2 and 3 for a set number of iterations.
I only need help with steps 1, 3, and 4. The style transfer process causes slight differences to form in the processed tiles, so I need to blend them back together. By feathering I mean fading one tile into another to hide the boundaries (as in ImageMagick, Photoshop, etc.). I am trying to accomplish this blending with masks created by torch.linspace(), though I'm not sure if there's a better approach.
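To illustrate the kind of mask I mean, here is a minimal sketch (the function name blend_horizontal and the tensors left and right are hypothetical, just for illustration) that fades two horizontally adjacent tiles into each other with complementary linear ramps:

import torch

# Hypothetical sketch: blend two (1, 3, H, W) tiles that share `overlap`
# columns, using complementary 0->1 and 1->0 ramps across the overlap.
def blend_horizontal(left, right, overlap):
    ramp = torch.linspace(0, 1, overlap).view(1, 1, 1, overlap)  # broadcast over N, C, H
    blended = left[..., -overlap:] * (1 - ramp) + right[..., :overlap] * ramp
    return torch.cat([left[..., :-overlap], blended, right[..., overlap:]], dim=3)

The complementary weights sum to one everywhere in the overlap, so a flat region stays flat after blending.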
What I am trying to accomplish is based on/inspired by: https://github.com/VaKonS/neural-style/blob/Multi-resolution/neural_style.lua, though I’m working with PyTorch. The code I am trying to implement tiling with can be found here: https://gist.github.com/ProGamerGov/e64fcb309274c2946f5a9a679ed45669, though you shouldn’t need to look at it as everything you need can be found below.
In essence, this is what I am trying to do (the red areas are regions where a tile overlaps its neighbor).
This is what I have for code so far. Feathering and adding rows together are not yet implemented, as I can't get the individual tile feathering working yet.
import torch
from PIL import Image
import torchvision.transforms as transforms
def tile_calc(tile_size, v, d):
    # Start/end indices for tile v along a dimension of size d; the last
    # tile is shifted back so that every tile is a full tile_size wide.
    max_val = max(min(tile_size * v + tile_size, d), 0)
    min_val = tile_size * v
    if abs(min_val - max_val) < tile_size:
        min_val = max_val - tile_size
    return min_val, max_val
def split_tensor(tensor, tile_size=256):
    # Split an NCHW tensor into overlapping tiles, recording each tile's
    # indices so the tiles can be placed back later.
    tiles, tile_idx = [], []
    tile_size_y, tile_size_x = tile_size + 8, tile_size + 5  # Make H and W different for testing
    h, w = tensor.size(2), tensor.size(3)
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))  # ceiling division
    for y in range(h_range):
        for x in range(w_range):
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)
            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])
    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]
    if tensor.is_cuda:
        base_tensor = torch.zeros(tensor.squeeze(0).size(), device=tensor.get_device())
    else:
        base_tensor = torch.zeros(tensor.squeeze(0).size())
    return tiles, base_tensor.unsqueeze(0), (h_range, w_range), (h_overlap, w_overlap)
# Feather across the vertical seams between horizontally adjacent tiles
def feather_tiles(tensor_list, hxw, w_overlap):
    print(len(tensor_list))
    mask_list = []
    if w_overlap > 0:
        for i, tile in enumerate(tensor_list):
            if i % hxw[1] != 0:  # every tile except the first in its row
                # Ramp from 0 to 1 across the left overlap, solid ones elsewhere
                lin_mask = torch.linspace(0, 1, w_overlap, device=tile.device).repeat(tile.size(2), 1)
                mask_part = torch.ones(tile.size(2), tile.size(3) - w_overlap, device=tile.device)
                mask = torch.cat([lin_mask, mask_part], 1)
                mask = mask.repeat(3, 1, 1).unsqueeze(0)
                mask_list.append(mask)
            else:
                mask = torch.ones(tile.squeeze().size(), device=tile.device).unsqueeze(0)
                mask_list.append(mask)
    return mask_list
def build_row(tensor_tiles, tile_masks, hxw, w_overlap, bt, tile_size):
    print(len(tensor_tiles), len(tile_masks))
    # A blank canvas the width of the base tensor, one per row of tiles
    if bt.is_cuda:
        row_base = torch.ones(bt.size(1), tensor_tiles[0].size(2), bt.size(3), device=bt.get_device()).unsqueeze(0)
    else:
        row_base = torch.ones(bt.size(1), tensor_tiles[0].size(2), bt.size(3)).unsqueeze(0)
    row_list = []
    for _ in range(hxw[0]):  # one canvas per row of tiles (hxw[0] rows, not hxw[1])
        row_list.append(row_base.clone())
    num_tiles = 0
    row_val = 0
    tile_size_y, tile_size_x = tile_size + 8, tile_size + 5
    h, w = bt.size(2), bt.size(3)
    h_range, w_range = hxw[0], hxw[1]
    for y in range(h_range):
        for x in range(w_range):
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)
            if num_tiles % hxw[1] != 0:  # blend this tile into the row with its mask
                # Match the mean brightness of the overlapping regions before blending
                new_mean = (row_list[row_val][:, :, :, tx:x_val].mean() + tensor_tiles[num_tiles].mean()) / 2
                row_list[row_val][:, :, :, tx:x_val] = row_list[row_val][:, :, :, tx:x_val] - row_list[row_val][:, :, :, tx:x_val].mean()
                tensor_tiles[num_tiles] = tensor_tiles[num_tiles] - tensor_tiles[num_tiles].mean()
                row_list[row_val][:, :, :, tx:x_val] = (row_list[row_val][:, :, :, tx:x_val] + (tensor_tiles[num_tiles] * tile_masks[num_tiles])) + new_mean
            else:  # the first tile in a row is pasted in directly
                row_list[row_val][:, :, :, tx:x_val] = tensor_tiles[num_tiles]
            num_tiles += 1
        row_val += 1
    return row_list
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
input_tensor = preprocess('test.jpg', 256)
tile_tensors, base_t, hxw, ovlp = split_tensor(input_tensor, 128)
tile_masks = feather_tiles(tile_tensors, hxw, ovlp[1])
row_tensors = build_row(tile_tensors, tile_masks, hxw, ovlp[1], base_t, 128)
ft = deprocess(row_tensors[0]) # save tensor to view it
ft.save('ft_row_0.png')
Answers
I was able to create a solution here that works for any tile size, image size, and pattern: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
I used masks to blend the tiles back together again.
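The core idea, sketched below with hypothetical names (the real dream_tile.py handles corner tiles, both axes, and the rounding edge cases), is to give each tile a mask that ramps across its overlap regions, so that the masks of adjacent tiles sum to one where they meet:

import torch

# Hypothetical sketch: a (1, 1, h, w) mask that ramps 0 -> 1 over the left
# overlap and/or 1 -> 0 over the right overlap. Multiply each tile by its
# mask, then add the masked tiles into the output canvas at their indices.
def edge_mask(h, w, overlap, ramp_left=True, ramp_right=True):
    mask = torch.ones(1, 1, h, w)
    if ramp_left:
        mask[..., :overlap] = torch.linspace(0, 1, overlap).view(1, 1, 1, overlap)
    if ramp_right:
        mask[..., -overlap:] = torch.linspace(1, 0, overlap).view(1, 1, 1, overlap)
    return mask

The same construction applied along the height dimension handles the seams between rows.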
You are looking for torch.nn.functional.unfold and torch.nn.functional.fold. These functions allow you to apply "sliding window" operations on images with arbitrary window sizes and strides. This answer gives more information about these functions, and this answer gives an example of how to "blend" overlapping windows using fold. These references should give you the information you need to implement your blending scheme.
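Here is a minimal sketch of that blending trick, under my own assumptions (a 256x256 input, 128-pixel windows, and a stride chosen so the windows tile the image exactly); it averages the overlapping regions uniformly rather than feathering them:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)  # stand-in for the processed image tensor
k, s = 128, 64                   # window size and stride (overlap = k - s)

patches = F.unfold(x, kernel_size=k, stride=s)  # (1, 3*k*k, num_windows)
# ... run each window through the model here ...
summed = F.fold(patches, output_size=(256, 256), kernel_size=k, stride=s)
# fold() sums overlapping patches, so divide by the number of windows
# covering each pixel to turn the sum into an average:
counts = F.fold(torch.ones_like(patches), output_size=(256, 256),
                kernel_size=k, stride=s)
blended = summed / counts

Replacing the all-ones tensor with a per-window ramp (and dividing by the fold of those ramps) would give the feathered blending described in the question.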