zjowowen's picture
init space
079c32c
raw
history blame
589 Bytes
from typing import Dict
import torch
import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .q_learning import DQN
@MODEL_REGISTRY.register('sqn')
class SQN(nn.Module):
    """Pair of independently constructed ``DQN`` networks evaluated together.

    Both sub-networks are built from the same constructor arguments and are
    stored as ``q0`` and ``q1``. Registered in ``MODEL_REGISTRY`` under the
    name ``'sqn'``.

    NOTE(review): the name presumably stands for "Soft Q-Network" -- confirm
    against the policy that consumes this model.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the two ``DQN`` heads.

        Args:
            *args: Positional arguments forwarded unchanged to each ``DQN``.
            **kwargs: Keyword arguments forwarded unchanged to each ``DQN``.
        """
        super().__init__()
        # Keep q0 before q1: submodule registration order follows
        # attribute assignment order.
        self.q0 = DQN(*args, **kwargs)
        self.q1 = DQN(*args, **kwargs)

    def forward(self, data: torch.Tensor) -> Dict:
        """Evaluate both heads on the same input.

        Args:
            data: Input passed as-is to ``q0`` and then to ``q1``.

        Returns:
            A dict with two entries:
            ``'q_value'`` -- list ``[q0 logit, q1 logit]``;
            ``'logit'`` -- the ``'logit'`` output of ``q0`` only.
        """
        first_logit = self.q0(data)['logit']
        second_logit = self.q1(data)['logit']
        return dict(
            q_value=[first_logit, second_logit],
            logit=first_logit,
        )