Learning a color quantization for neural excitation
Introduction
I have recently been playing with the diffvg library. It lets you optimize a set of lines so that the rendered output satisfies specific constraints, like this example.
However, in this case the number of colors is not controlled, so the result cannot be plotted with a regular set of pens. That constraint is very important if you want to draw the result with an Axidraw.
A simpler problem: neural style with quantization
As the library is quite complex, I decided to start small by doing the regular neuron excitation work with a limited set of colors. To do so, one can proceed as follows:
- The optimized image is of size [H, W, N_COLORS]
- A second tensor stores the colors for each color class
- The highest value along the third dimension of the image determines the color at a given pixel
However, an argmax is not differentiable, so during optimization we use a softmax instead.
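# Per-pixel logits over the color classes, shape [H, W, n_colors] (float32 by default).
self.class_matrix = torch.nn.Parameter(torch.randn(H, W, n_colors))
# Unconstrained palette parameters, shape [n_colors, 3]; the sigmoid below maps them into (0, 1).
self.class_colors = torch.nn.Parameter(torch.randn(n_colors, 3) + 0.5)
# Relaxed image: the softmax weights mix the palette colors at every pixel.
img = torch.clip(F.softmax(self.class_matrix, dim=2) @ torch.sigmoid(self.class_colors), 0, 1)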
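The snippet above can be assembled into a small module. This is a minimal sketch rather than the exact code used for the post: the class name `PaletteImage` and the image size are assumptions made for illustration, and the `torch.clip` is left out because a softmax-weighted mix of sigmoid outputs already lies in [0, 1].

```python
import torch
import torch.nn.functional as F


class PaletteImage(torch.nn.Module):
    """Image described by per-pixel class logits and a small palette (hypothetical name)."""

    def __init__(self, H, W, n_colors):
        super().__init__()
        # One logit per pixel and per color class.
        self.class_matrix = torch.nn.Parameter(torch.randn(H, W, n_colors))
        # Unconstrained palette values; the sigmoid in forward() maps them into (0, 1).
        self.class_colors = torch.nn.Parameter(torch.randn(n_colors, 3) + 0.5)

    def forward(self):
        # Relaxed image: each pixel is a convex mix of the palette colors.
        weights = F.softmax(self.class_matrix, dim=2)      # [H, W, n_colors]
        return weights @ torch.sigmoid(self.class_colors)  # [H, W, 3]


pal_im = PaletteImage(H=8, W=8, n_colors=4)
img = pal_im()
```

Both parameters receive gradients from any loss computed on `img`, so the class map and the palette are learned jointly.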
Results
This first image shows the difference between the true image with the strict one-color-per-pixel rule (left) and the relaxed version used for optimization (right). Colors on the right show different shades of a given hue.
An image as seen during the optimization
The true image with classes
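The two renderings can be computed from the same parameters. The sketch below assumes random stand-ins for the learned logits and palette (the variable names are illustrative): the relaxed image uses the softmax mix, while the strict image picks a single palette entry per pixel with an argmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, W, n_colors = 4, 4, 3
class_matrix = torch.randn(H, W, n_colors)          # stand-in for the learned logits
palette = torch.sigmoid(torch.randn(n_colors, 3))   # stand-in for the learned colors

# Relaxed image used during optimization: a soft mix of palette colors.
relaxed = F.softmax(class_matrix, dim=2) @ palette  # [H, W, 3]

# Strict image: every pixel takes exactly one palette color.
classes = class_matrix.argmax(dim=2)                # [H, W], integer class ids
strict = palette[classes]                           # [H, W, 3]

# The strict image never contains more distinct colors than the palette.
n_distinct = torch.unique(strict.reshape(-1, 3), dim=0).shape[0]
```

Only the strict image is guaranteed to be plottable with `n_colors` pens, which is why the gap between the two renderings matters.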
Some tricks
Controlling the colors with a sigmoid
torch.sigmoid(self.class_colors)
keeps every color component within the valid (0, 1) range.
A spiky softmax
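F.softmax(3 * self.class_matrix, dim=2)
reduces color cheating: multiplying the logits by 3 pushes the softmax closer to one-hot, so each pixel relies less on blends of several palette colors.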
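To see the effect of the scaling factor on a single pixel, here is a small sketch. The logit values are made up for illustration; only the factor 3 comes from the snippet above.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.5, 0.2, 1.0])  # illustrative logits for one pixel

soft = F.softmax(logits, dim=0)         # plain softmax
spiky = F.softmax(3 * logits, dim=0)    # same logits, scaled by 3

# The largest weight grows and the others shrink, so the pixel moves
# closer to a single palette color.
print(soft, spiky)
```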
A penalty for a non-spiky class_matrix
loss += 0.01 * torch.norm(F.softmax(pal_im.class_matrix, dim=2) ** 0.5, 1)
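Here, we want arrays like [0.2, 0.2, 0.6] to have a higher norm than [1.0, 0, 0]. The plain L1 norm of a softmax output is always 1, so we take the square root of the softmax output of the color class image before computing the norm: the first array then scores about 1.67 and the second exactly 1.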
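The penalty can be checked on the two example arrays. In this sketch the helper name `spikiness_penalty` is an assumption introduced for illustration; it applies the same square root and L1 norm as the loss term.

```python
import torch


def spikiness_penalty(probs):
    # L1 norm of the element-wise square root, as in the loss term.
    return torch.norm(probs ** 0.5, 1)


mixed = torch.tensor([0.2, 0.2, 0.6])
one_hot = torch.tensor([1.0, 0.0, 0.0])

p_mixed = spikiness_penalty(mixed)      # sqrt(0.2) + sqrt(0.2) + sqrt(0.6), about 1.67
p_one_hot = spikiness_penalty(one_hot)  # exactly 1.0
```

Because the square root is concave, spreading probability mass over several classes always increases the sum, so minimizing this term pushes each pixel toward a single class.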