Deep Hebbian Image Encoder

This project explores biologically-inspired Hebbian methods to train stackable image encoders. These models can organize visual information without traditional loss functions or backpropagation. Afterward, the learned image embeddings are visualized using UMAP + K-Means Clustering.

Below is a sample clustering of learned image embeddings from the Tiny ImageNet dataset.

Source code on GitHub

Why Hebbian? (Biological Analogy)

Unlike traditional neural networks that rely on error signals and gradient descent, Hebbian models operate more like biological brains. There's no supervision, no target output, and no error propagation: just forward passes, with connections strengthened between neurons that activate together.

"Neurons that fire together, wire together."

This means the network never knows what the correct answer is; it simply recognizes and reinforces co-occurrence in the data. If two features appear frequently at the same time, the connection between them strengthens. Over time, this produces internal representations that reflect the structure of the data, without labels or supervision.
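
As a toy illustration (not part of the project code), a single scalar weight updated with a plain Hebbian rule grows almost entirely from the samples where both units are active at the same time:

eta = 0.1
w = 0.0  # connection strength between feature A and feature B

co_occurring = [(1.0, 1.0), (0.9, 0.8), (1.0, 0.9)]   # A and B active together
unrelated    = [(1.0, 0.0), (0.0, 0.7), (0.0, 0.0)]   # little or no co-activation

for a, b in co_occurring + unrelated:
    w += eta * a * b  # Hebb: strengthen in proportion to co-activation

print(w)  # ~0.26, contributed entirely by the co-occurring pairs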

This approach mirrors part of the brain's learning strategy, where connections are updated locally based on experience rather than global goals. It's simpler and more biologically plausible, and that simplicity might be useful for efficient or modular systems in the future.

Dataset Preparation

Images are resized to 64×64 for standardization across datasets. For the Pokémon dataset, we extract sprites from a transparent PNG grid:

sprites = load_spritesheet("pokemon_all_transparent.png", sprite_size=(64, 64), tile_size=(96, 96))
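
load_spritesheet is a project helper; a minimal sketch of what such a function could look like is below (the tiling logic and exact signature here are assumptions, not the project's implementation):

from PIL import Image
from torchvision import transforms

def load_spritesheet(path, sprite_size=(64, 64), tile_size=(96, 96)):
    # Hypothetical sketch: cut the sheet into fixed-size tiles, resize each tile to sprite_size
    sheet = Image.open(path).convert("RGBA")
    tw, th = tile_size
    to_tensor = transforms.ToTensor()
    sprites = []
    for top in range(0, sheet.height - th + 1, th):
        for left in range(0, sheet.width - tw + 1, tw):
            tile = sheet.crop((left, top, left + tw, top + th))
            sprites.append(to_tensor(tile.resize(sprite_size, Image.BILINEAR)))
    return sprites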

For Tiny ImageNet:

dataset = ImageFolder("../data/tiny_imagenet/train", transform=transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
]))

Hebbian Encoder Architecture

Each Hebbian encoder layer applies:

  • A %%3 \times 3%% convolution with stride 2
  • ReLU activation
  • L2 normalization over spatial dimensions
  • Lateral weights updated by the Hebbian rule:

$$ \Delta W_{ij} = \eta \cdot \langle a_i \cdot a_j \rangle $$

Where:

  • %%\Delta W_{ij}%% is the change in weight from unit %%j%% to %%i%%
  • %%\eta%% is the learning rate
  • %%a_i%% and %%a_j%% are the activations of units %%i%% and %%j%% respectively
  • %%\langle a_i \cdot a_j \rangle%% denotes the batch-averaged outer product
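
For example, with %%\eta = 0.001%% and a single sample whose flattened activations at three positions are %%a = (0.8, 0.4, 0)%%, the update is the scaled outer product:

$$ \Delta W = 0.001 \cdot a a^\top = 0.001 \begin{pmatrix} 0.64 & 0.32 & 0 \\ 0.32 & 0.16 & 0 \\ 0 & 0 & 0 \end{pmatrix} $$

The two positions that fired together strengthen their mutual weights, while every connection involving the silent position stays at zero.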

This update reinforces co-activation patterns; inhibition is enforced later in the forward pass by subtracting the mean lateral activation. In code, the update is:

hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)  # per-channel co-activation, summed over the batch
delta = 0.001 * hebbian.mean(dim=0)
self.lateral_weights.data += delta
self.lateral_weights.data.clamp_(-1.0, 1.0)

Overall Network Architecture

Input Image (64×64×3/4)
       │
       ▼
┌─────────────────┐    Layer 1: 3/4→16 channels, 64×64→32×32
│  HebbianLayer1  │
│   Conv + Hebb   │    ┌─ Conv2d(3/4→16, k=3, s=2)
│                 │    ├─ ReLU + L2 Norm
│  [32×32×16]     │    ├─ Lateral Connections (16×1024×1024)
└─────────────────┘    └─ Mean Subtraction (Inhibition)
       │
       ▼
┌─────────────────┐    Layer 2: 16→32 channels, 32×32→16×16
│  HebbianLayer2  │
│   Conv + Hebb   │    ├─ Lateral Connections (32×256×256)
│  [16×16×32]     │
└─────────────────┘
       │
       ▼
┌─────────────────┐    Layer 3: 32→64 channels, 16×16→8×8
│  HebbianLayer3  │
│   Conv + Hebb   │    ├─ Lateral Connections (64×64×64)
│   [8×8×64]      │
└─────────────────┘
       │
       ▼
┌─────────────────┐    Layer 4: 64→128 channels, 8×8→4×4
│  HebbianLayer4  │
│   Conv + Hebb   │    ├─ Lateral Connections (128×16×16)
│   [4×4×128]     │
└─────────────────┘
       │
       ▼
    Flatten
  [2048 features]
       │
       ▼
┌─────────────────┐
│  UMAP Reduce    │ ──► 2D Visualization
│   2048 → 2D     │
└─────────────────┘
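
The shape flow above can be sanity-checked with plain strided convolutions using the same kernel, stride, and padding as the encoder layers (sizes only; no Hebbian behavior involved):

import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),    # 64×64 → 32×32
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),   # 32×32 → 16×16
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # 16×16 → 8×8
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 8×8 → 4×4
)

x = torch.randn(1, 3, 64, 64)
y = convs(x)
print(y.shape)                       # torch.Size([1, 128, 4, 4])
print(y.flatten(start_dim=1).shape)  # torch.Size([1, 2048])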

Spatial-Aware Lateral Connections Explained

Single Channel Spatial Layout:

Original 4×4 feature map:          Flattened to 16 positions:
┌─────┬─────┬─────┬─────┐         ┌─────────────────────┐
│ a₀  │ a₁  │ a₂  │ a₃  │         │ a₀ a₁ a₂ a₃ a₄ ...  │
├─────┼─────┼─────┼─────┤   ────► │                     │
│ a₄  │ a₅  │ a₆  │ a₇  │         │ ... a₁₂ a₁₃ a₁₄ a₁₅ │
├─────┼─────┼─────┼─────┤         └─────────────────────┘
│ a₈  │ a₉  │ a₁₀ │ a₁₁ │
├─────┼─────┼─────┼─────┤
│ a₁₂ │ a₁₃ │ a₁₄ │ a₁₅ │
└─────┴─────┴─────┴─────┘

Lateral Weight Matrix (16×16 for this channel):

     a₀  a₁  a₂  a₃  a₄  a₅  a₆  a₇  a₈  a₉  a₁₀ a₁₁ a₁₂ a₁₃ a₁₄ a₁₅
a₀  │w₀₀ w₀₁ w₀₂ w₀₃ w₀₄ w₀₅ w₀₆ w₀₇ w₀₈ w₀₉ w₀₁₀w₀₁₁w₀₁₂w₀₁₃w₀₁₄w₀₁₅│
a₁  │w₁₀ w₁₁ w₁₂ w₁₃ w₁₄ w₁₅ w₁₆ w₁₇ w₁₈ w₁₉ w₁₁₀w₁₁₁w₁₁₂w₁₁₃w₁₁₄w₁₁₅│
a₂  │w₂₀ w₂₁ w₂₂ w₂₃ w₂₄ w₂₅ w₂₆ w₂₇ w₂₈ w₂₉ w₂₁₀w₂₁₁w₂₁₂w₂₁₃w₂₁₄w₂₁₅│
... │...                                                              │
a₁₅ │w₁₅₀w₁₅₁w₁₅₂w₁₅₃w₁₅₄w₁₅₅w₁₅₆w₁₅₇w₁₅₈w₁₅₉w₁₅₁₀w₁₅₁₁w₁₅₁₂w₁₅₁₃w₁₅₁₄w₁₅₁₅│

Each position can influence every other position in the same channel.

Without Lateral Connections:

┌───┐ ┌───┐ ┌───┐ ┌───┐
│ a │ │ b │ │ c │ │ d │  ← Independent activations
└───┘ └───┘ └───┘ └───┘    No communication between positions

With Lateral Connections:

┌───┐ ┌───┐ ┌───┐ ┌───┐
│ a │↔│ b │↔│ c │↔│ d │  ← Positions influence each other
└─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘
  ↕     ↕     ↕     ↕      Creates spatial competition
┌─┴─┐ ┌─┴─┐ ┌─┴─┐ ┌─┴─┐    and cooperation patterns
│ e │ │ f │ │ g │ │ h │
└───┘ └───┘ └───┘ └───┘

Multi-Channel Lateral Structure

For a layer with C channels and H×W spatial dimensions:

Channel 0:   [H×W positions] ──► [N×N lateral weights]
Channel 1:   [H×W positions] ──► [N×N lateral weights]
Channel 2:   [H×W positions] ──► [N×N lateral weights]
...
Channel C-1: [H×W positions] ──► [N×N lateral weights]

where N = H × W

Total lateral weights: C × N × N
Example for the 32×32→16×16 layer (32 channels):
self.lateral_weights.shape = [32, 256, 256]
                              ↑    ↑    ↑
                              │    │    └─ 256 "to" positions
                              │    └────── 256 "from" positions
                              └─────────── 32 independent channels
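
The lateral weight shapes (and counts) for the four layers in this project can be tallied directly:

# (out_channels, feature map side) for each layer
layers = [(16, 32), (32, 16), (64, 8), (128, 4)]

for idx, (C, side) in enumerate(layers, start=1):
    N = side * side        # spatial positions per channel
    total = C * N * N      # lateral weights for this layer
    print(f"Layer {idx}: lateral_weights shape [{C}, {N}, {N}] -> {total:,} weights")

# Layer 1: lateral_weights shape [16, 1024, 1024] -> 16,777,216 weights
# Layer 2: lateral_weights shape [32, 256, 256] -> 2,097,152 weights
# Layer 3: lateral_weights shape [64, 64, 64] -> 262,144 weights
# Layer 4: lateral_weights shape [128, 16, 16] -> 32,768 weights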

Hebbian Learning Step-by-Step

Step 1: Forward Pass

Input Batch: [B, C, H, W]
     │
     ▼ Conv2d + ReLU
Activations: [B, C, H', W']
     │
     ▼ L2 Normalize
Normalized: [B, C, H', W'] (unit norm per spatial map)

Step 2: Flatten for Lateral Processing

act_flat = activations.view(B, C, -1)   # [B, C, N]
Result shape: [B, C, N] where N = H'×W'

Step 3: Hebbian Update (No Gradients!)

# Co-activation outer products per channel (b = batch, n = channel, i/j = spatial positions)
hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)
#                       ↑     ↑     ↑
#                       │     │     └─ Output: [C, N, N], summed over the batch
#                       │     └─────── j-th position activations
#                       └───────────── i-th position activations

# Scale and update the lateral weights
delta = 0.001 * hebbian.mean(dim=0)  # mean over dim 0 → [N, N]
self.lateral_weights.data += delta   # broadcast across the [C, N, N] weight tensor

Step 4: Apply Lateral Connections

# Matrix multiply: activations × channel-specific lateral weights
lateral = torch.einsum("bci,cij->bcj", act_flat, self.lateral_weights)
#                      ↑     ↑     ↑
#                      │     │     └─ Output activations per channel [B, C, N]
#                      │     └─────── Channel-specific weights [N, N]
#                      └───────────── Input activations [B, C, N]

Step 5: Inhibition (Competition)

# Reshape back to the spatial grid, then subtract the per-map mean
lateral = lateral.view(B, C, H, W)
lateral = lateral - lateral.mean(dim=(2, 3), keepdim=True)
# Subtract mean → winner-take-all dynamics

Energy, Delta, Norm Logging

During training, the following values are logged per step:

  • Energy: mean squared activation across all units, i.e., %%\mathbb{E}[|a|^2]%%
  • Delta: mean absolute change in lateral weights
  • Norm: Frobenius norm of the lateral weight matrix, i.e., %%|W|_F%%

Example:

[LOG] Step 122: energy=0.1466, delta=0.000166, norm=155.7744

This gives insight into encoder dynamics: stable energy and delta values indicate convergence, while growing norm may suggest over-association.
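
For reference, these quantities can be computed from the layer state roughly as follows (a sketch mirroring the logging line in the full source below; act, delta, and lateral_weights stand for the layer's activations, the most recent weight change, and the lateral weight tensor):

def log_metrics(step, act, delta, lateral_weights):
    energy = act.pow(2).mean().item()          # mean squared activation
    mean_delta = delta.abs().mean().item()     # mean absolute lateral weight change
    frob = lateral_weights.norm().item()       # Frobenius norm of the lateral weights
    print(f"[LOG] Step {step}: energy={energy:.4f}, delta={mean_delta:.6f}, norm={frob:.4f}")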

Feature Extraction

Images are passed through a multi-layer encoder consisting of 4 HebbianEncoder layers. The final feature map is flattened to a 1D vector and stored.

features = model(images)
features = F.normalize(features, dim=1).cpu().numpy()

Hebbian Network Structure

The Hebbian encoder processes 64×64 images using a stack of convolutional layers with stride 2. Each layer halves the spatial resolution while increasing the channel count. Each Hebbian layer also includes lateral recurrent weights trained with Hebbian updates to reinforce co-activation patterns.

# For RGB images (Tiny ImageNet)
model = MultiLayerHebbian([
    (3, 16, (32, 32)),
    (16, 32, (16, 16)),
    (32, 64, (8, 8)),
    (64, 128, (4, 4))
])

# For RGBA images (Pokémon sprites)
model = MultiLayerHebbian([
    (4, 16, (32, 32)),
    (16, 32, (16, 16)),
    (32, 64, (8, 8)),
    (64, 128, (4, 4))
])

Each tuple in the list specifies the parameters for a HebbianEncoder layer:

(in_channels, out_channels, spatial_shape)

This configuration maps as follows:

Layer | Input Channels | Output Channels | Input Spatial Size | Output Spatial Size
------|----------------|-----------------|--------------------|--------------------
1     | 3/4 (RGB/RGBA) | 16              | 64×64              | 32×32
2     | 16             | 32              | 32×32              | 16×16
3     | 32             | 64              | 16×16              | 8×8
4     | 64             | 128             | 8×8                | 4×4

This structure results in a final feature tensor of shape (B, 128, 4, 4) per image, which is flattened to (B, 2048) and used for clustering and visualization.

The spatial argument in each HebbianEncoder is required to initialize the lateral weight matrix:

self.lateral_weights = torch.nn.Parameter(torch.zeros(C, N, N))

where N = H × W (the number of spatial positions per channel). These weights are updated using Hebbian learning:

$$ \Delta W_{ij} = \eta \cdot \langle a_i \cdot a_j \rangle $$

which in code becomes:

hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)
delta = 0.001 * hebbian.mean(dim=0)
self.lateral_weights.data += delta

This configuration balances spatial compression and representational capacity, while the Hebbian lateral updates encourage neurons to specialize by detecting and reinforcing co-activation patterns.

Embedding and Clustering Visualization

UMAP is used to project feature vectors to 2D.

reducer = UMAP(n_components=2, random_state=42)
reduced = reducer.fit_transform(features)

Optionally, KMeans clustering is applied to the feature space:

kmeans = KMeans(n_clusters=6, random_state=42).fit(features)

These steps are for visualization and evaluation only, not part of the training objective; they let us inspect whether the encoder has organized inputs meaningfully.
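
A minimal sketch of combining the two into a labeled scatter plot (colors and filename are illustrative, not the project's thumbnail canvas shown below):

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from umap import UMAP

# features: [num_images, 2048] L2-normalized embeddings from the encoder
reduced = UMAP(n_components=2, random_state=42).fit_transform(features)
labels = KMeans(n_clusters=6, random_state=42).fit_predict(features)

plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, cmap="tab10", s=8)
plt.title("Hebbian embeddings: UMAP projection colored by K-Means cluster")
plt.savefig("umap_kmeans_clusters.png", dpi=150)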

Layout and Plotting

Projected coordinates are normalized to fit inside a square canvas (e.g., 2500×2500 pixels). Margin padding ensures images are not clipped.

reduced -= reduced.min(axis=0)
reduced /= (reduced.max(axis=0) + 1e-8)
reduced *= (canvas_size - 2 * margin)
reduced += margin

Sprites are drawn onto the canvas using their corresponding (x, y) UMAP coordinates.

for (x, y), img in zip(reduced, sprites):
    pil = transforms.ToPILImage()(img).resize((sprite_size, sprite_size))
    canvas.paste(pil, (int(x), int(y)), mask=pil if has_alpha else None)
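
The snippet above assumes a PIL canvas and a few variables prepared elsewhere; a minimal version of that setup might look like this (names and values are illustrative):

from PIL import Image

canvas_size = 2500
margin = 64
sprite_size = 64
has_alpha = True  # RGBA sprites (Pokémon); False for RGB datasets

# Black square canvas that the sprites get pasted onto
canvas = Image.new("RGBA" if has_alpha else "RGB", (canvas_size, canvas_size), "black")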

Code + Results

Tiny ImageNet (Hebbian)

Source

Pokémon Full RGBA Hebbian

Source

Pokémon Similarity

The first column contains randomly selected Pokémon, and the 5 most similar Pokémon are listed to the right.

Source

Hebbian Image Encoder (Single-Layer)

This was the first prototype. The first column contains a randomly selected image (from Tiny ImageNet), and the 5 most similar images are listed to the right.

Source

Hebbian Deep Image Encoder - Tiny ImageNet - Source Code

import sys
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from umap import UMAP

# === Config ===
IMAGE_DIR = "../data/tiny_imagenet/train"
SPRITE_SIZE = (64, 64)
BATCH_SIZE = 8
NUM_IMAGES = 1000
EPOCHS = 50
CLUSTERS = 4
EMBED_SIZE = 32  # thumbnail size
CANVAS_SIZE = 2500

# === Hebbian Network ===
class HebbianEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, spatial):
        super().__init__()
        self.encode = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        C = out_channels
        N = spatial[0] * spatial[1]
        self.lateral_weights = torch.nn.Parameter(torch.zeros(C, N, N))

    def forward(self, x, step=None):
        act = F.relu(self.encode(x))
        act = act / (act.norm(dim=(2, 3), keepdim=True) + 1e-6)
        B, C, H, W = act.shape
        act_flat = act.view(B, C, -1)

        # Hebbian update of the lateral weights (no gradients involved)
        with torch.no_grad():
            # Per-channel co-activation, summed over the batch: [C, N, N]
            hebbian = torch.einsum("bni,bnj->nij", act_flat, act_flat)
            delta = 0.001 * hebbian.mean(dim=0)
            self.lateral_weights.data += delta
            self.lateral_weights.data.clamp_(-1.0, 1.0)

        # Apply lateral connections, then subtract the per-map mean (inhibition)
        lateral = torch.einsum("bci,cij->bcj", act_flat, self.lateral_weights)
        lateral = lateral.view(B, C, H, W)
        lateral = lateral - lateral.mean(dim=(2, 3), keepdim=True)
        act += lateral

        if step is not None:
            print(f"[LOG] Step {step}: energy={act.pow(2).mean():.4f}, delta={delta.abs().mean():.6f}, norm={self.lateral_weights.data.norm():.4f}")

        return act

class MultiLayerHebbian(torch.nn.Module):
    def __init__(self, layer_shapes):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            HebbianEncoder(in_c, out_c, spatial) for (in_c, out_c, spatial) in layer_shapes
        ])

    def forward(self, x, step=None):
        for i, layer in enumerate(self.layers):
            x = layer(x, step=step if i == len(self.layers) - 1 else None)
        return x.view(x.size(0), -1).detach()

# === Load Dataset ===
def load_dataset():
    transform = transforms.Compose([
        transforms.Resize(SPRITE_SIZE),
        transforms.ToTensor()
    ])
    dataset = ImageFolder(IMAGE_DIR, transform=transform)
    subset = torch.utils.data.Subset(dataset, list(range(min(NUM_IMAGES, len(dataset)))))
    loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)
    return subset, loader

# === Plot Utility ===
def plot_with_images(embeddings, images, title="Hebbian Clusters", size=32, canvas_size=2500):
    fig, ax = plt.subplots(figsize=(canvas_size / 100, canvas_size / 100), facecolor='black', dpi=100)
    ax.set_facecolor('black')
    ax.set_title(title, color='white')
    ax.set_xticks([])
    ax.set_yticks([])

    from matplotlib.offsetbox import OffsetImage, AnnotationBbox

    # Normalize coordinates to canvas
    margin = size * 2
    embeddings -= embeddings.min(axis=0)
    embeddings /= (embeddings.max(axis=0) + 1e-8)
    embeddings *= (canvas_size - 2 * margin)
    embeddings += margin

    for (x, y), img_tensor in zip(embeddings, images):
        img = transforms.ToPILImage()(img_tensor).resize((size, size), resample=Image.BILINEAR).convert("RGB")
        imbox = OffsetImage(img, zoom=1.5)  # zoom factor for visibility
        ab = AnnotationBbox(imbox, (x, y), frameon=False)
        ax.add_artist(ab)

    ax.set_xlim(0, canvas_size)
    ax.set_ylim(0, canvas_size)
    ax.invert_yaxis()
    plt.tight_layout()
    plt.savefig("tinyimagenet_hebbian_cluster_plot.png", facecolor='black')
    print("[SAVED] tinyimagenet_hebbian_cluster_plot.png")

# === Main ===
if __name__ == "__main__":
    dataset, dataloader = load_dataset()
    model = MultiLayerHebbian([
        (3, 16, (32, 32)),
        (16, 32, (16, 16)),
        (32, 64, (8, 8)),
        (64, 128, (4, 4))
    ])

    all_features = []
    for epoch in range(EPOCHS):
        for step, (batch, _) in enumerate(dataloader):
            z = model(batch, step=step) if epoch == EPOCHS - 1 else model(batch)
            if epoch == EPOCHS - 1:
                all_features.append(z)

    features = torch.cat(all_features, dim=0).cpu().numpy()
    features = np.nan_to_num(features)
    features /= (np.linalg.norm(features, axis=1, keepdims=True) + 1e-6)

    reducer = UMAP(n_components=2, random_state=42) #, min_dist=0.2)
    reduced = reducer.fit_transform(features)

    margin = EMBED_SIZE // 2
    reduced -= reduced.min(axis=0)
    reduced /= (reduced.max(axis=0) + 1e-8)
    reduced *= (CANVAS_SIZE - 2 * margin)
    reduced += margin

    all_images = [img for img, _ in dataset]
    plot_with_images(reduced, all_images)

