
For completeness, here is a text summary of what I am trying to do:

  1. Split the image into tiles.
  2. Run each tile through a new copy of the model for a set number of iterations.
  3. Feather tiles and put them into rows.
  4. Feather rows and put them back together into the original image/tensor.
  5. Maybe save the output, then split the output into tiles again.
  6. Repeat steps 2 and 3 for a set number of iterations.
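
A minimal sketch of step 1, assuming square tiles and a fixed minimum overlap that is smaller than the tile size. The helper names `tile_starts` and `split_into_tiles` are hypothetical and are not taken from the linked gist; the last tile along each axis is pushed flush against the far edge, so it may overlap its neighbour by more than the requested amount.

```python
import torch

def tile_starts(length, tile, overlap):
    """Start offsets of tiles of size `tile` that cover `length` pixels.

    Neighbouring tiles share at least `overlap` pixels (assumes overlap < tile).
    """
    if tile >= length:
        return [0]  # a single (clipped) tile covers the whole axis
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile sits flush with the far edge
    return starts

def split_into_tiles(tensor, tile=128, overlap=16):
    """Split an (N, C, H, W) tensor into overlapping tiles, row by row."""
    h, w = tensor.shape[-2:]
    coords = [(y, x)
              for y in tile_starts(h, tile, overlap)
              for x in tile_starts(w, tile, overlap)]
    tiles = [tensor[..., y:y + tile, x:x + tile] for y, x in coords]
    return tiles, coords
```

Keeping the `(y, x)` offsets next to the tiles means the same coordinates can be reused when the processed tiles are pasted back.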

I only need help with steps 1, 3, and 4. The style transfer process introduces slight differences between the processed tiles, so I need to blend them back together. By feathering I mean fading one tile into the next to blur the boundaries (as in ImageMagick, Photoshop, etc.). I am trying to do this blending with masks created by torch.linspace(), though I’m not sure if there’s a better approach.
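
To make the feathering idea concrete, here is a small sketch of a linear crossfade between two horizontally adjacent tiles. It assumes the last `overlap` columns of the left tile and the first `overlap` columns of the right tile cover the same pixels, and that `overlap` is at least 1. The function name `crossfade_horizontal` is made up for illustration.

```python
import torch

def crossfade_horizontal(left, right, overlap):
    """Join two (N, C, H, W) tiles, fading from `left` to `right` in the overlap."""
    # Weight for the right tile ramps from 0 to 1 across the overlap.
    # Shape (overlap,) broadcasts over the N, C and H dimensions.
    ramp = torch.linspace(0, 1, overlap, device=left.device, dtype=left.dtype)
    # The two weights sum to 1 at every column, so brightness is preserved.
    blended = left[..., -overlap:] * (1 - ramp) + right[..., :overlap] * ramp
    return torch.cat([left[..., :-overlap], blended, right[..., overlap:]], dim=-1)
```

Because the weights sum to 1, two tiles that already agree in the overlap come out unchanged; only disagreements are smoothed. The same function works for rows if the tensors are transposed so that height is the last dimension.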

What I am trying to accomplish is based on/inspired by: https://github.com/VaKonS/neural-style/blob/Multi-resolution/neural_style.lua, though I’m working with PyTorch. The code I am trying to implement tiling with can be found here: https://gist.github.com/ProGamerGov/e64fcb309274c2946f5a9a679ed45669, though you shouldn’t need to look at it as everything you need can be found below.

In essence this is what I am trying to do (red areas overlap with another tile):

[Figure: diagram of what I am trying to accomplish. Red regions are where a tile overlaps with another tile.]

This is what I have for code so far. Feathering and adding rows together is not yet implemented as I can’t get the individual tile feathering working yet.

import torch
from PIL import Image
import torchvision.transforms as transforms

def tile_calc(tile_size, v, d):
    max_val = max(min(tile_size*v+tile_size, d), 0)
    min_val = tile_size*v
    if abs(min_val - max_val) < tile_size:
        min_val = max_val-tile_size
    return min_val, max_val

def split_tensor(tensor, tile_size=256):
    tiles, tile_idx = [], []
    tile_size_y, tile_size_x = tile_size+8, tile_size +5 # Make H and W different for testing
    h, w = tensor.size(2), tensor.size(3)
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))

    for y in range(h_range):       
        for x in range(w_range):        
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])

    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]

    # Zero canvas with the same shape, dtype and device as the input
    base_tensor = torch.zeros_like(tensor)
    return tiles, base_tensor, (h_range, w_range), (h_overlap, w_overlap)

# Build horizontal feathering masks (they blend the vertical seams within a row)
def feather_tiles(tensor_list, hxw, w_overlap):
    mask_list = []
    for i, tile in enumerate(tensor_list):
        # All-ones mask with the tile's shape, dtype and device
        mask = torch.ones_like(tile)
        if w_overlap > 0 and i % hxw[1] != 0:
            # Not the first tile in its row: fade in from 0 to 1 across the left overlap
            mask[:, :, :, :w_overlap] = torch.linspace(0, 1, w_overlap, device=tile.device)
        mask_list.append(mask)
    return mask_list


def build_row(tensor_tiles, tile_masks, hxw, w_overlap, bt, tile_size):
    print(len(tensor_tiles), len(tile_masks))
    # One blank row canvas on the same device as the base tensor
    row_base = torch.ones(1, bt.size(1), tensor_tiles[0].size(2), bt.size(3), device=bt.device)
    row_list = []
    for v in range(hxw[0]):  # one canvas per row of tiles (hxw[0] is the row count)
        row_list.append(row_base.clone())

    num_tiles = 0
    row_val = 0
    tile_size_y, tile_size_x = tile_size+8, tile_size +5
    h, w = bt.size(2), bt.size(3)
    h_range, w_range = hxw[0], hxw[1]
    for y in range(h_range):       
        for x in range(w_range):        
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            if num_tiles % hxw[1] != 0: 
                new_mean = (row_list[row_val][:, :, :, tx:x_val].mean() + tensor_tiles[num_tiles])/2
                row_list[row_val][:, :, :, tx:x_val] = row_list[row_val][:, :, :, tx:x_val] - row_list[row_val][:, :, :, tx:x_val].mean()
                tensor_tiles[num_tiles] = tensor_tiles[num_tiles] - tensor_tiles[num_tiles].mean()  

                row_list[row_val][:, :, :, tx:x_val] = (row_list[row_val][:, :, :, tx:x_val] + ( tensor_tiles[num_tiles] * tile_masks[num_tiles])) + new_mean

            else:
                row_list[row_val][:, :, :, tx:x_val] = tensor_tiles[num_tiles]          
            num_tiles+=1 
        row_val+=1          
    return row_list


def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


input_tensor = preprocess('test.jpg', 256)

tile_tensors, base_t, hxw, ovlp = split_tensor(input_tensor, 128)
tile_masks = feather_tiles(tile_tensors, hxw, ovlp[1])
row_tensors = build_row(tile_tensors, tile_masks, hxw, ovlp[1], base_t, 128)

ft = deprocess(row_tensors[0]) # save tensor to view it 
ft.save('ft_row_0.png')

2 Answers


  1. Chosen as BEST ANSWER

    I was able to create a solution here that works for any tile size, image size, and pattern: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py

    I used masks to blend the tiles back together again.


  2. You are looking for torch.nn.functional.unfold and torch.nn.functional.fold.
    These functions allow you to apply “sliding window” operations to images with arbitrary window sizes and strides.
    This answer gives more information about these functions, and this answer gives an example of how to “blend” overlapping windows using fold.
    These references should give you the information you need to implement your blending scheme.
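
As a rough sketch of how unfold and fold could be combined with feathering masks: each tile is multiplied by a window that fades towards its edges, the weighted tiles are summed back with fold, and the result is divided by the folded window weights. This is an illustration under stated assumptions, not the code from either answer. It assumes `(H - tile)` and `(W - tile)` are both multiples of `stride` (otherwise some border pixels are never covered and the division yields NaN), and `blend_with_fold` and its `process` callback are hypothetical names.

```python
import torch
import torch.nn.functional as F

def blend_with_fold(image, tile=128, stride=96, process=lambda t: t):
    """Tile an (N, C, H, W) image, process the tiles, and feather them back."""
    n, c, h, w = image.shape
    # unfold gives (N, C*tile*tile, L); each of the L columns is one flattened tile.
    patches = F.unfold(image, kernel_size=tile, stride=stride)
    num_tiles = patches.shape[-1]
    tiles = patches.transpose(1, 2).reshape(n * num_tiles, c, tile, tile)
    tiles = process(tiles)  # hypothetical per-tile step, e.g. style transfer

    # Feathering window: small near the tile edges, largest in the centre.
    # The endpoints 0 are dropped so every weight is strictly positive.
    ramp = torch.linspace(0, 1, tile + 2, device=image.device, dtype=image.dtype)[1:-1]
    ramp = torch.minimum(ramp, ramp.flip(0))
    window = ramp[:, None] * ramp[None, :]  # (tile, tile)

    # Back to the (N, C*tile*tile, L) layout that fold expects.
    weighted = (tiles * window).reshape(n, num_tiles, -1).transpose(1, 2).contiguous()
    weights = (window.expand(n, num_tiles, c, tile, tile)
               .reshape(n, num_tiles, -1).transpose(1, 2).contiguous())

    # fold sums overlapping contributions; dividing by the summed weights
    # turns that sum into a weighted average.
    out = F.fold(weighted, output_size=(h, w), kernel_size=tile, stride=stride)
    norm = F.fold(weights, output_size=(h, w), kernel_size=tile, stride=stride)
    return out / norm
```

With the default identity `process`, the output should match the input up to floating-point error, which is a convenient sanity check before plugging in a real model.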
