import torch
from torch import nn


class Discriminator(nn.Module):
    """Fully connected ReLU network mapping each input vector to one scalar score.

    The architecture is ``hidden_layers`` Linear+ReLU stages of width
    ``hidden_dim`` followed by a final Linear projection to a single output.
    No activation is applied after the final Linear, so the output is an
    unnormalized score (logit).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
        hidden_layers: Number of Linear+ReLU hidden stages; must be >= 1.

    Raises:
        ValueError: If ``hidden_layers`` is less than 1.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 256, hidden_layers: int = 6) -> None:
        super().__init__()
        # range() of a non-positive count is empty, so hidden_layers <= 0 would
        # silently still build one hidden layer; reject it explicitly instead.
        if hidden_layers < 1:
            raise ValueError(f"hidden_layers must be >= 1, got {hidden_layers}")
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        # The first stage above already counts as one hidden layer.
        for _ in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw score for ``x``; output shape is ``(*x.shape[:-1], 1)``."""
        return self.network(x)