In [ ]:
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial

def batch_norm(x):
    # Normalize each feature over the batch dimension (dim 0), using the
    # biased variance, as batch norm does at training time.
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def layer_norm(x):
    # Normalize each sample over its feature dimension (dim 1).
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def group_norm(x, num_groups):
    # Split each sample's features into num_groups groups and normalize
    # within each group; C must be divisible by num_groups.
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm
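
# Quick sanity check (an added sketch, not part of the original model):
# the hand-rolled norms should match PyTorch's functional equivalents
# when the learnable affine parameters are left unset (weight=bias=None).
_x = torch.randn(8, 16)
assert torch.allclose(batch_norm(_x),
                      F.batch_norm(_x, None, None, training=True, eps=1e-5),
                      atol=1e-5)
assert torch.allclose(layer_norm(_x),
                      F.layer_norm(_x, (16,), eps=1e-5),
                      atol=1e-5)
assert torch.allclose(group_norm(_x, num_groups=4),
                      F.group_norm(_x, 4, eps=1e-5),
                      atol=1e-5)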

class MLP(nn.Module):
    # Two-layer MLP that applies a pluggable normalization function to the
    # hidden pre-activations, before the ReLU.
    def __init__(self, input_dim, hidden_dim, output_dim, norm_func):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.norm_func = norm_func
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.norm_func(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

# Create a random tensor with size (batch_size, input_dim)
x = torch.randn(32, 100)

# Create the MLP models with batch norm, layer norm, and group norm
model_bn = MLP(100, 64, 10, batch_norm)
model_ln = MLP(100, 64, 10, layer_norm)
model_gn = MLP(100, 64, 10, partial(group_norm, num_groups=4))

# Pass the input tensor through the models
output_bn = model_bn(x)
output_ln = model_ln(x)
output_gn = model_gn(x)

# Print the output shapes
print("Output shape with batch norm:", output_bn.shape)
print("Output shape with layer norm:", output_ln.shape)
print("Output shape with group norm:", output_gn.shape)