Implementation of Denoising Diffusion Model#

diffusion generated

Jupyter Notebook

Data Preparation#

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def prepare_dataset(batch_size):
    """Build a shuffling DataLoader over FashionMNIST.

    Images are converted with ``transforms.ToTensor``; the dataset is stored
    under ``./data`` and downloaded there if it is missing.

    Args:
        batch_size (int): Number of samples per batch.

    Returns:
        DataLoader: Iterable yielding ``(images, labels)`` batches.
    """
    fashion_mnist = datasets.FashionMNIST(
        root="./data",
        download=True,
        transform=transforms.ToTensor(),
    )
    return DataLoader(fashion_mnist, batch_size=batch_size, shuffle=True)
batch_size = 128

# Build the training loader once; the preview cell below draws a batch from it.
dataloader = prepare_dataset(batch_size=batch_size)
import matplotlib.pyplot as plt

def _plot_image_grid(images, labels, cols, rows):
    """Lay out up to ``rows * cols`` grayscale images on a new figure.

    Cells beyond ``len(images)`` are left blank, so passing fewer images
    than grid cells is allowed.

    Returns:
        matplotlib.figure.Figure: The populated figure.
    """
    fig = plt.figure(figsize=(cols, rows))
    for idx in range(rows * cols):
        ax = fig.add_subplot(rows, cols, idx + 1)
        if idx >= len(images):
            # Not enough images to fill the grid: hide the empty cell.
            ax.axis("off")
            continue
        ax.imshow(images[idx], cmap="gray")
        if labels is not None:
            ax.set_xlabel(labels[idx])
        # Pixel coordinates carry no information: hide ticks on both axes.
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    return fig


def show_images(images, labels=None, cols=10, rows=3):
    """Display images in a ``rows`` x ``cols`` grid.

    Args:
        images: Indexable sequence of 2-D arrays/tensors (one per cell).
        labels: Optional sequence of captions, indexed like ``images``.
        cols (int): Number of grid columns.
        rows (int): Number of grid rows.
    """
    _plot_image_grid(images, labels, cols, rows)
    plt.show()


def save_images(images, labels=None, cols=10, rows=3, save_path="images.png"):
    """Render images in a ``rows`` x ``cols`` grid and write it to ``save_path``.

    Args:
        images: Indexable sequence of 2-D arrays/tensors (one per cell).
        labels: Optional sequence of captions, indexed like ``images``.
        cols (int): Number of grid columns.
        rows (int): Number of grid rows.
        save_path (str): Output file path passed to ``Figure.savefig``.
    """
    fig = _plot_image_grid(images, labels, cols, rows)
    fig.savefig(save_path)
    plt.close(fig)

# Draw one shuffled batch and preview it (default grid: 10 x 3 images).
data_iter = iter(dataloader)
images, labels = next(data_iter)
preview = [channels[0] for channels in images]  # keep channel 0 -> 2-D image for imshow
show_images(preview)
../_images/f299882977768ec409ae409411842a87132687adf0a0b43a8e5e8d77198b4243.png

Model Preparation#

import torch
import torch.nn as nn
def time_embedding(time_steps, time_dim, device='cpu'):
    """Sinusoidal embedding of diffusion time steps.

    Even columns hold ``sin(t / div_term[0::2])`` and odd columns hold
    ``cos(t / div_term[1::2])`` with
    ``div_term[i] = 10000 ** (i / (2 * time_dim))``.

    Args:
        time_steps: 1-D tensor (or sequence) of time steps, one per batch element.
        time_dim (int): Embedding width.
        device: Device on which the embedding is created.

    Returns:
        torch.Tensor: float32 tensor of shape ``(len(time_steps), time_dim)``.
    """
    # (N, 1) column of float32 steps so it broadcasts against the frequency row.
    steps = torch.as_tensor(time_steps, device=device).reshape(-1, 1).to(torch.float32)

    idx = torch.arange(0, time_dim, device=device)
    div_term = torch.exp(idx / (2 * time_dim) * torch.log(torch.tensor(10000.0)))
    # NOTE(review): the exponent is idx / (2 * time_dim), so div_term spans
    # 1 .. ~10000**0.5 (= 100), and sin/cos neighbours use different
    # frequencies -- unlike the usual 10000**(2i/d) Transformer encoding.
    # Kept unchanged so existing checkpoints see identical embeddings.

    embeddings = torch.zeros(steps.shape[0], time_dim, device=device)
    # One broadcast op per parity instead of a Python loop over the batch;
    # this function runs on every forward pass (max_steps times per sample).
    embeddings[:, 0::2] = torch.sin(steps / div_term[0::2])
    embeddings[:, 1::2] = torch.cos(steps / div_term[1::2])
    return embeddings

We use a U-Net (Ronneberger et al., 2015; MICCAI 2015) as the noise-prediction network in this example. Unet

class ConvBlock(nn.Module):
    """
    Two 3x3 convolutions, each followed by BatchNorm and ReLU, optionally
    conditioned on a time embedding that is added to the input first.
    """
    def __init__(self, in_channels, out_channels, time_dim):
        """
        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by both convolutions.
            time_dim (int): Width of the time-embedding rows passed to ``forward``.
        """
        super().__init__()
        # Layer order fixes both the state_dict keys ("conv.0", "conv.1", ...)
        # and the order in which parameters are randomly initialised.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        hidden = 2 * in_channels
        # Small MLP mapping a time embedding to one additive value per input channel.
        self.linear = nn.Sequential(
            nn.Linear(time_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )

    def forward(self, x, t=None):
        """
        Args:
            x (torch.Tensor): Feature map whose first dimension is the batch.
            t (torch.Tensor, optional): Time embedding, one row per batch
                element. When ``None`` the block is a plain conv block.
        Returns:
            torch.Tensor: Feature map with ``out_channels`` channels.
        """
        if t is None:
            return self.conv(x)
        # Reshape to (batch, channels, 1, 1) so the bias broadcasts over H and W.
        bias = self.linear(t).view(x.size(0), -1, 1, 1)
        return self.conv(x + bias)


class UNetWithTime(nn.Module):
    """
    Two-level U-Net in which every ConvBlock is conditioned on a sinusoidal
    embedding of the diffusion time step.
    """
    def __init__(self, in_channels=1, out_channels=1, time_dim=512):
        """
        Args:
            in_channels (int): Channels of the input image.
            out_channels (int): Channels of the predicted output.
            time_dim (int): Width of the time embedding shared by all blocks.
        """
        super().__init__()
        self.time_dim = time_dim
        self.time_embedding = time_embedding

        # Encoder: in_channels -> 64 -> 128 channels.
        self.down_conv1 = ConvBlock(in_channels, 64, time_dim)
        self.down_conv2 = ConvBlock(64, 128, time_dim)
        # Bottleneck after two 2x poolings.
        self.bottleneck = ConvBlock(128, 256, time_dim)
        # Decoder inputs = upsampled features concatenated with the matching skip.
        self.up_conv2 = ConvBlock(256 + 128, 128, time_dim)
        self.up_conv1 = ConvBlock(128 + 64, 64, time_dim)
        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)
        # Parameter-free resampling layers reused at both levels.
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, time_steps):
        """
        Args:
            x (torch.Tensor): Input batch (N, C, H, W); H and W are assumed
                divisible by 4 so upsampled maps line up with the skips.
            time_steps (torch.Tensor): One time step per batch element.
        Returns:
            torch.Tensor: Output with ``out_channels`` channels.
        """
        emb = self.time_embedding(time_steps, self.time_dim, device=x.device)

        skip1 = self.down_conv1(x, emb)
        skip2 = self.down_conv2(self.max_pool(skip1), emb)
        h = self.bottleneck(self.max_pool(skip2), emb)

        h = self.up_conv2(torch.cat([self.up_sample(h), skip2], dim=1), emb)
        h = self.up_conv1(torch.cat([self.up_sample(h), skip1], dim=1), emb)
        return self.out(h)

Diffusion and Denoising#

class Diffuser:
    """
    DDPM noise schedule with linearly spaced betas, plus forward (noising)
    and reverse (denoising/sampling) steps.

    Time steps are 1-indexed: ``t`` in ``1..max_steps`` maps to schedule
    index ``t - 1``.
    """
    def __init__(self, max_steps, beta_start=0.001, beta_end=0.01, device='cpu'):
        """
        Args:
            max_steps (int): Number of diffusion steps T.
            beta_start (float): beta at the first step.
            beta_end (float): beta at the last step.
            device: Device holding the schedule tensors and sampled noise.
        """
        self.device = device
        self.max_steps = max_steps
        self.betas = torch.linspace(beta_start, beta_end, max_steps, device=self.device)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)
        # alpha_bar_{t-1} for every t, with alpha_bar_0 = 1. Indexing
        # alphas_cumprod[t - 2] directly would wrap to the LAST entry at t = 1.
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=self.device), self.alphas_cumprod[:-1]]
        )

    def diffuse(self, x, t):
        """
        Sample x_t from q(x_t | x_0) in closed form.

        Args:
            x (torch.Tensor): Clean batch x_0 of shape (N, C, H, W).
            t (torch.Tensor): Integer time steps in 1..max_steps, one per element.
        Returns:
            tuple: ``(x_t, noise)``, where ``noise`` is the Gaussian sample used.
        """
        noise = torch.randn_like(x, device=self.device)
        alpha_cumprod = self.alphas_cumprod[t - 1].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_cumprod) * x + torch.sqrt(1 - alpha_cumprod) * noise

        return x_t, noise

    def denoise(self, model, x, t):
        """
        One reverse step x_t -> x_{t-1} using the model's noise prediction.

        Args:
            model (nn.Module): Noise predictor called as ``model(x, t)``.
            x (torch.Tensor): Current noisy batch x_t of shape (N, C, H, W).
            t (torch.Tensor): Integer time steps in 1..max_steps, one per element.
        Returns:
            torch.Tensor: x_{t-1}; no noise is added where ``t == 1``.
        """
        idx = t - 1
        alpha = self.alphas[idx].view(-1, 1, 1, 1)
        alpha_cumprod = self.alphas_cumprod[idx].view(-1, 1, 1, 1)
        alpha_cumprod_prev = self.alphas_cumprod_prev[idx].view(-1, 1, 1, 1)

        # Predict in eval mode (BatchNorm uses running statistics), then put
        # the model back into whichever mode the caller had it in.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                pred_noise = model(x, t)
        finally:
            model.train(was_training)

        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0  # the final step is deterministic

        mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * pred_noise) / torch.sqrt(alpha)
        # Posterior std: sqrt(beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)).
        std = torch.sqrt((1 - alpha) * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod))
        return mu + noise * std

    def convert_to_image(self, x):
        """Scale values by 255, clamp to [0, 255] and cast to uint8.

        Assumes model-space pixels lie in [0, 1] (as produced by ToTensor).
        """
        return (x * 255).clamp(0, 255).to(torch.uint8)

    def sample(self, model, sample_shape=(1, 1, 28, 28)):
        """
        Generate images by running the full reverse chain from Gaussian noise.

        Args:
            model (nn.Module): Noise predictor called as ``model(x, t)``.
            sample_shape (tuple): Output shape; its first entry is the batch size.
        Returns:
            torch.Tensor: uint8 images of shape ``sample_shape``.
        """
        bsz = sample_shape[0]
        x = torch.randn(sample_shape, device=self.device)

        for i in range(self.max_steps, 0, -1):
            t = torch.full((bsz,), i, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t)

        return self.convert_to_image(x)

Training#

import os
from torch.optim import Adam
import torch.nn.functional as F
# Hyperparameters
in_channels = 1
img_size = 28
batch_size = 128
max_steps = 1000
time_dim = 100
lr = 2e-3
epochs = 5

# Prefer CUDA, then Apple MPS, then fall back to the CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Prepare dataset
dataloader = prepare_dataset(batch_size)

# Initialize model, noise schedule and optimizer
model = UNetWithTime(in_channels=in_channels, time_dim=time_dim).to(device)
diffuser = Diffuser(max_steps=max_steps, device=device)
optimizer = Adam(model.parameters(), lr=lr)
def train(epochs, model, diffuser, optimizer, dataloader, max_steps, img_size, batch_size, device):
    """
    Train ``model`` to predict the noise injected by ``diffuser.diffuse``.

    At the start of every epoch, 64 samples are drawn and displayed, so the
    epoch-0 preview shows the untrained model.

    Args:
        epochs (int): Number of passes over ``dataloader``.
        model (nn.Module): Noise predictor called as ``model(x_t, t)``.
        diffuser (Diffuser): Provides ``diffuse`` and ``sample``.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        dataloader: Iterable of ``(images, labels)`` batches; labels are ignored.
        max_steps (int): Time steps are drawn uniformly from ``1..max_steps``.
        img_size (int): Height/width of the preview samples.
        batch_size (int): Unused; kept for call compatibility (the loader
            already determines the batch size).
        device: Device that training batches are moved to.

    Returns:
        list[float]: Mean training loss per epoch (``nan`` if an epoch had no batches).
    """
    losses = []
    for epoch in range(epochs):
        preview = diffuser.sample(model, (64, 1, img_size, img_size))
        show_images([img[0] for img in preview.to('cpu')])

        #os.makedirs("data/FashionMNIST/generated", exist_ok=True)
        #save_images([img[0] for img in preview.to('cpu')], save_path=f"data/FashionMNIST/generated/images_{epoch}.png")

        # Sampling may toggle the model's mode; make sure updates run in train mode.
        model.train()
        loss_sum = 0.0
        n_batches = 0
        for images, _labels in dataloader:
            optimizer.zero_grad()
            x = images.to(device)
            t = torch.randint(1, max_steps + 1, (len(x),), device=device)

            x_noisy, noise = diffuser.diffuse(x, t)
            noise_pred = model(x_noisy, t)
            loss = F.mse_loss(noise_pred, noise)  # (input, target)

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        loss_avg = loss_sum / n_batches if n_batches else float('nan')
        losses.append(loss_avg)
        print(f'Epoch {epoch} | Loss: {loss_avg}')

    # save model
    #torch.save(model.state_dict(), "data/FashionMNIST/model.pth")
    return losses
# Run the training loop; it returns one mean loss per epoch.
losses = train(
    epochs=epochs,
    model=model,
    diffuser=diffuser,
    optimizer=optimizer,
    dataloader=dataloader,
    max_steps=max_steps,
    img_size=img_size,
    batch_size=batch_size,
    device=device,
)
Hide code cell output
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[12], line 1
----> 1 losses = train(epochs, model, diffuser, optimizer, dataloader, max_steps, img_size, batch_size, device)

Cell In[11], line 8, in train(epochs, model, diffuser, optimizer, dataloader, max_steps, img_size, batch_size, device)
      5 loss_sum = 0.0
      6 cnt = 0
----> 8 images = diffuser.sample(model, (64, 1, img_size, img_size))
      9 print(len(images))
     10 show_images([img[0] for img in images.to('cpu')])

Cell In[8], line 66, in Diffuser.sample(self, model, sample_shape)
     64 for i in range(self.max_steps, 0, -1):
     65     t = torch.tensor([i] * bsz, device=self.device, dtype=torch.long)
---> 66     x = self.denoise(model, x, t)
     68 x = torch.stack([self.convert_to_image(x[i]) for i in range(bsz)])
     69 return x

Cell In[8], line 44, in Diffuser.denoise(self, model, x, t)
     42 model.eval()
     43 with torch.no_grad():
---> 44     pred_noise = model(x, t)
     45 model.train()
     47 noise = torch.randn_like(x, device=self.device)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[7], line 92, in UNetWithTime.forward(self, x, time_steps)
     90 x = self.up_conv2(torch.cat([x, x2], dim=1), t)
     91 x = self.up_sample(x)
---> 92 x = self.up_conv1(torch.cat([x, x1], dim=1), t)
     93 x = self.out(x)
     95 return x

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[7], line 40, in ConvBlock.forward(self, x, t)
     38     t = t.view(x.size(0), -1, 1, 1)
     39     x = x + t
---> 40 x = self.conv(x)
     42 return x

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File ~/.cache/pypoetry/virtualenvs/loem-notes-kagQPLM5-py3.11/lib/python3.11/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

KeyboardInterrupt: 
Hide code cell source
# Plot the per-epoch training loss and keep a copy on disk.
os.makedirs("data/FashionMNIST/generated", exist_ok=True)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Save BEFORE showing: show() may close the figure, after which savefig
# would write a new, empty figure.
plt.savefig("data/FashionMNIST/generated/loss.png")
plt.show()
plt.close()

# Draw a full batch of unconditional samples from the trained model.
images = diffuser.sample(model, (batch_size, 1, img_size, img_size))
show_images([img[0] for img in images.to('cpu')])
../_images/5bd5cd4a0b95f213c917b84b56eb16bf7d071dcc24860cb62bfa2cbbab6a84a5.png ../_images/73cd376bcc369a8ca0622516c1fbb7c5b269b09b0109d1c934447135f653b210.png

Conditional Diffusion Model#

In the previous section, we formulated a generative model as a probability distribution \(p(x)\) over observed data \(x\). In practice, however, we often want to control the generated data with specific conditions \(y\), i.e. to model \(p(x|y)\), rather than sampling at random. For example, we may want to generate an image of a cat in a specific pose, or of a person with a particular hairstyle. Conditional diffusion models excel at this type of task, allowing us to steer the generative process towards desired outcomes.

In this section, we’ll delve deeper into the inner workings of conditional diffusion models. We’ll explore how they leverage the concept of score-based models and diffusion processes to achieve their impressive results. We’ll also discuss some of the challenges and limitations of these models, and how researchers are working to address them.

To take conditions into account, we consider conditional diffusion model as the following conditional probabilistic distribution.

\[ p_\theta(x|y) = \int p_\theta(x|x_1, y) \cdots p_\theta(x_{T-1}|x_T, y)\, p_\theta(x_T)\, dx_1 \cdots dx_T \]

In the previous chapter, we discussed how to train a denoising diffusion model with a neural-network decoder that either restores the original data or estimates the added noise. In noise estimation, the noisy data and the corresponding time step are the inputs of the decoder. To extend this model to handle conditions \(y\), we need to revise the decoder architecture so that it also takes the features of the conditions into account during training.

Conditional Diffusion

There are many ways to implement this. A simple one is to represent the condition (which can be text, an image, a label, etc.) as an embedding vector, just as we do for the time step.

Embedding

class ConditionalUNetWithTime(nn.Module):
    """
    U-Net with time embedding and optional class-label conditioning.

    When ``labels_num`` is given, a learned label embedding of width
    ``time_dim`` is added to the sinusoidal time embedding, so every
    ConvBlock is conditioned on both the time step and the label.
    """
    def __init__(self, in_channels=1, out_channels=1, time_dim=512, labels_num=None):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            time_dim (int): Dimension of the time (and label) embedding.
            labels_num (int, optional): Number of label classes. When ``None``
                the model has no label embedding and must be called without labels.
        """
        super().__init__()
        self.time_dim = time_dim
        self.labels_num = labels_num

        self.time_embedding = time_embedding
        # down sampling
        self.down_conv1 = ConvBlock(in_channels, 64, time_dim)
        self.down_conv2 = ConvBlock(64, 128, time_dim)
        # bottleneck
        self.bottleneck = ConvBlock(128, 256, time_dim)
        # up sampling
        self.up_conv2 = ConvBlock(256 + 128, 128, time_dim)  # concat with down sampling
        self.up_conv1 = ConvBlock(128 + 64, 64, time_dim)  # concat with down sampling
        # output
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)
        # max pooling
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # up sampling
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Label embedding: decided by the constructor argument, never by any
        # global variable that happens to be in scope.
        if labels_num is not None:
            self.label_embedding = nn.Embedding(labels_num, time_dim)
        else:
            self.label_embedding = None

    def forward(self, x, time_steps, labels=None):
        """
        Args:
            x (torch.Tensor): Input batch (N, C, H, W); H and W are assumed
                divisible by 4 so upsampled maps line up with the skips.
            time_steps (torch.Tensor): One time step per batch element.
            labels (torch.Tensor or sequence of int, optional): One class
                index per batch element.
        Returns:
            torch.Tensor: Output with ``out_channels`` channels.
        Raises:
            ValueError: If ``labels`` is given but the model was built
                without ``labels_num``.
        """
        t = self.time_embedding(time_steps, self.time_dim, device=x.device)

        if labels is not None:
            if self.label_embedding is None:
                raise ValueError(
                    "labels were passed but the model was built without labels_num"
                )
            t = t + self.label_embedding(torch.as_tensor(labels, device=x.device))

        x1 = self.down_conv1(x, t)
        x2 = self.down_conv2(self.max_pool(x1), t)
        x = self.bottleneck(self.max_pool(x2), t)
        x = self.up_conv2(torch.cat([self.up_sample(x), x2], dim=1), t)
        x = self.up_conv1(torch.cat([self.up_sample(x), x1], dim=1), t)
        return self.out(x)
class Diffuser:
    """
    DDPM noise schedule with linearly spaced betas, plus forward (noising)
    and reverse (denoising/sampling) steps, optionally label-conditioned.

    Time steps are 1-indexed: ``t`` in ``1..max_steps`` maps to schedule
    index ``t - 1``.
    """
    def __init__(self, max_steps, beta_start=0.001, beta_end=0.01, device='cpu'):
        """
        Args:
            max_steps (int): Number of diffusion steps T.
            beta_start (float): beta at the first step.
            beta_end (float): beta at the last step.
            device: Device holding the schedule tensors and sampled noise.
        """
        self.device = device
        self.max_steps = max_steps
        self.betas = torch.linspace(beta_start, beta_end, max_steps, device=self.device)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)
        # alpha_bar_{t-1} for every t, with alpha_bar_0 = 1. Indexing
        # alphas_cumprod[t - 2] directly would wrap to the LAST entry at t = 1.
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=self.device), self.alphas_cumprod[:-1]]
        )

    def diffuse(self, x, t):
        """
        Sample x_t from q(x_t | x_0) in closed form.

        Args:
            x (torch.Tensor): Clean batch x_0 of shape (N, C, H, W).
            t (torch.Tensor): Integer time steps in 1..max_steps, one per element.
        Returns:
            tuple: ``(x_t, noise)``, where ``noise`` is the Gaussian sample used.
        """
        noise = torch.randn_like(x, device=self.device)
        alpha_cumprod = self.alphas_cumprod[t - 1].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_cumprod) * x + torch.sqrt(1 - alpha_cumprod) * noise

        return x_t, noise

    def denoise(self, model, x, t, labels=None):
        """
        One reverse step x_t -> x_{t-1} using the model's noise prediction.

        Args:
            model (nn.Module): Noise predictor called as ``model(x, t, labels)``.
            x (torch.Tensor): Current noisy batch x_t of shape (N, C, H, W).
            t (torch.Tensor): Integer time steps in 1..max_steps, one per element.
            labels (torch.Tensor, optional): Class labels forwarded to the model.
        Returns:
            torch.Tensor: x_{t-1}; no noise is added where ``t == 1``.
        """
        idx = t - 1
        alpha = self.alphas[idx].view(-1, 1, 1, 1)
        alpha_cumprod = self.alphas_cumprod[idx].view(-1, 1, 1, 1)
        alpha_cumprod_prev = self.alphas_cumprod_prev[idx].view(-1, 1, 1, 1)

        # Predict in eval mode (BatchNorm uses running statistics), then put
        # the model back into whichever mode the caller had it in.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                pred_noise = model(x, t, labels)
        finally:
            model.train(was_training)

        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0  # the final step is deterministic

        mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * pred_noise) / torch.sqrt(alpha)
        # Posterior std: sqrt(beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)).
        std = torch.sqrt((1 - alpha) * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod))
        return mu + noise * std

    def convert_to_image(self, x):
        """Scale values by 255, clamp to [0, 255] and cast to uint8.

        Assumes model-space pixels lie in [0, 1] (as produced by ToTensor).
        """
        return (x * 255).clamp(0, 255).to(torch.uint8)

    def sample(self, model, sample_shape=(1, 1, 28, 28), labels=None):
        """
        Generate images by running the full reverse chain from Gaussian noise.

        Args:
            model (nn.Module): Noise predictor called as ``model(x, t, labels)``.
            sample_shape (tuple): Output shape; its first entry is the batch size.
            labels (sequence of int or torch.Tensor, optional): One class label
                per sample.
        Returns:
            tuple: ``(images, labels)``. ``images`` is a uint8 tensor of shape
            ``sample_shape``; ``labels`` is a tensor on ``self.device``, or ``None``.
        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        bsz = sample_shape[0]
        x = torch.randn(sample_shape, device=self.device)

        if labels is not None:
            # Explicit check: a bare assert would be stripped under ``python -O``.
            if len(labels) != bsz:
                raise ValueError(
                    f"Batch size ({bsz}) must be equal to the number of labels ({len(labels)})."
                )
            labels = torch.as_tensor(labels).to(self.device)

        for i in range(self.max_steps, 0, -1):
            t = torch.full((bsz,), i, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t, labels)

        return self.convert_to_image(x), labels
# Hyperparameters
in_channels = 1
img_size = 28
batch_size = 128
max_steps = 1000
time_dim = 100
lr = 2e-3
epochs = 5

# Prefer CUDA, then Apple MPS, then fall back to the CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Prepare dataset
dataloader = prepare_dataset(batch_size)

# Initialize the label-conditioned model (10 classes), noise schedule and optimizer
model = ConditionalUNetWithTime(in_channels=in_channels, time_dim=time_dim, labels_num=10).to(device)
diffuser = Diffuser(max_steps=max_steps, device=device)
optimizer = Adam(model.parameters(), lr=lr)

# Display names for class indices 0-9, in index order.
label2text = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
def train_conditional_diffusion(epochs, model, diffuser, optimizer, dataloader, max_steps, img_size, batch_size, device):
    """
    Train a label-conditioned noise predictor.

    At the start of every epoch, one sample per class index 0..9 is drawn and
    displayed with its name from the module-level ``label2text``. The
    epoch-0 preview therefore shows the untrained model.

    Args:
        epochs (int): Number of passes over ``dataloader``.
        model (nn.Module): Noise predictor called as ``model(x_t, t, labels)``.
        diffuser (Diffuser): Provides ``diffuse`` and label-aware ``sample``.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        dataloader: Iterable of ``(images, labels)`` batches.
        max_steps (int): Time steps are drawn uniformly from ``1..max_steps``.
        img_size (int): Height/width of the preview samples.
        batch_size (int): Unused; kept for call compatibility (the loader
            already determines the batch size).
        device: Device that training batches are moved to.

    Returns:
        list[float]: Mean training loss per epoch (``nan`` if an epoch had no batches).
    """
    losses = []
    for epoch in range(epochs):
        preview, preview_labels = diffuser.sample(
            model, (10, 1, img_size, img_size), labels=list(range(10))
        )
        show_images(
            [img[0] for img in preview.to('cpu')],
            labels=[label2text[i] for i in preview_labels.tolist()],
            cols=5,
            rows=2,
        )

        #os.makedirs("data/FashionMNIST/generated", exist_ok=True)
        #save_images([img[0] for img in preview.to('cpu')], save_path=f"data/FashionMNIST/generated/images_{epoch}.png")

        # Sampling may toggle the model's mode; make sure updates run in train mode.
        model.train()
        loss_sum = 0.0
        n_batches = 0
        for images, labels in dataloader:
            optimizer.zero_grad()
            x = images.to(device)
            y = labels.to(device)
            t = torch.randint(1, max_steps + 1, (len(x),), device=device)

            x_noisy, noise = diffuser.diffuse(x, t)
            noise_pred = model(x_noisy, t, y)
            loss = F.mse_loss(noise_pred, noise)  # (input, target)

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        loss_avg = loss_sum / n_batches if n_batches else float('nan')
        losses.append(loss_avg)
        print(f'Epoch {epoch} | Loss: {loss_avg}')

    # save model
    #torch.save(model.state_dict(), "data/FashionMNIST/model.pth")
    return losses
# Train the conditional model; one mean loss is recorded per epoch.
losses = train_conditional_diffusion(
    epochs=epochs,
    model=model,
    diffuser=diffuser,
    optimizer=optimizer,
    dataloader=dataloader,
    max_steps=max_steps,
    img_size=img_size,
    batch_size=batch_size,
    device=device,
)
Hide code cell output
../_images/3f820cca87c110c69c7feac47413df52fe8d53d034b80cde6f639e53e40d9f8d.png
Epoch 0 | Loss: 0.05732659705634564
../_images/2f0efbb840952780cc489cbf840dac2196e7b90f6cb1d34ca71d42a47559e9c8.png
Epoch 1 | Loss: 0.034125066118072596
../_images/16cb8d4b3139ad78159b013cffe086193ddb3b3e389fd5c68f8cf86c11751ea5.png
Epoch 2 | Loss: 0.030119409371636062
../_images/b554cbb5fd2b4f2bb7f036ade322758e61f01f9bf13c912bc14510f9f571b26c.png
Epoch 3 | Loss: 0.027875443813261954
../_images/786f7ac21445b84a7b503316a7226486fa759261968d7f8fe911e299113855df.png
Epoch 4 | Loss: 0.02637626212447691
Hide code cell source
# Plot the per-epoch training loss and keep a copy on disk.
os.makedirs("data/FashionMNIST/generated", exist_ok=True)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Save BEFORE showing: show() may close the figure, after which savefig
# would write a new, empty figure.
plt.savefig("data/FashionMNIST/generated/conditionalUnet_loss.png")
plt.show()
plt.close()

# One conditional sample per class index 0..9, captioned with its class name.
images, labels = diffuser.sample(model, (10, 1, img_size, img_size), labels=list(range(10)))
show_images(
    [img[0] for img in images.to('cpu')],
    labels=[label2text[i] for i in labels.tolist()],
    cols=5,
    rows=2,
)
../_images/1be5d3ffc75ada80e3de1d0830ae2f37dfb7b6fae8427c7d2f44a3ce0f535631.png ../_images/c346d1fe92e37b60d2240091b110c96f7051894f5634f076a93db762ca898809.png