|
|
|
|
|
|
|
import torch |
|
|
|
try: |
|
import curope as _kernels |
|
except ModuleNotFoundError: |
|
from . import curope as _kernels |
|
|
|
|
|
class cuRoPE2D_func(torch.autograd.Function):
    """Autograd wrapper around the compiled ``rope_2d`` kernel.

    The kernel is called as ``_kernels.rope_2d(tokens, positions, base, F0)``
    and rewrites ``tokens`` in place, so ``forward`` marks ``tokens`` dirty and
    returns it. ``backward`` runs the same kernel on (a copy of) the incoming
    gradient with ``F0`` negated.

    NOTE(review): negating ``F0`` yields the correct gradient only if the
    kernel applies a linear map whose transpose equals the same map with
    ``-F0`` (true for a rotation) — confirm against the curope sources.
    """

    @staticmethod
    def forward(ctx, tokens, positions, base, F0=1):
        """Apply the kernel to ``tokens`` in place and return it.

        Args:
            ctx: autograd context.
            tokens: tensor overwritten in place by the kernel. Presumably
                (B, N, H, D) — TODO confirm against the curope kernel.
            positions: tensor forwarded to the kernel and saved for backward;
                it receives no gradient.
            base: scalar forwarded to the kernel unchanged.
            F0: scalar forwarded to the kernel; its sign is flipped in backward.

        Returns:
            ``tokens`` itself, now holding the kernel's output.
        """
        ctx.save_for_backward(positions)
        ctx.saved_base = base
        ctx.saved_F0 = F0

        _kernels.rope_2d(tokens, positions, base, F0)
        # The input was overwritten, so autograd has to be told about it.
        ctx.mark_dirty(tokens)
        return tokens

    @staticmethod
    def backward(ctx, grad_res):
        """Return the gradient w.r.t. ``tokens``; other inputs get ``None``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_res: gradient of the loss w.r.t. ``forward``'s output.

        Returns:
            ``(grad_tokens, None, None, None)`` — one entry per ``forward``
            input (``tokens``, ``positions``, ``base``, ``F0``).
        """
        positions = ctx.saved_tensors[0]
        base, F0 = ctx.saved_base, ctx.saved_F0

        # Work on a private, contiguous copy. Autograd may hand the same
        # gradient tensor to several nodes (e.g. both operands of an add), or
        # pass an expanded zero-stride view (e.g. after ``.sum()``); running
        # the in-place kernel on ``grad_res`` directly could corrupt other
        # gradients or write through aliased memory.
        grad = grad_res.clone(memory_format=torch.contiguous_format)
        _kernels.rope_2d(grad, positions, base, -F0)
        # No mark_dirty here: it is only meaningful inside forward().
        return grad, None, None, None
|
|
|
|
|
class cuRoPE2D(torch.nn.Module):
    """``nn.Module`` front-end for ``cuRoPE2D_func`` with fixed parameters.

    Args:
        freq: value handed to the kernel as ``base`` on every call.
        F0: value handed to the kernel as ``F0`` on every call.
    """

    def __init__(self, freq=100.0, F0=1.0):
        super().__init__()
        # Plain Python scalars, stored as-is (not registered as buffers).
        self.base, self.F0 = freq, F0

    def forward(self, tokens, positions):
        """Run the kernel on ``tokens`` in place and return the same tensor.

        Axes 1 and 2 of ``tokens`` are swapped before the call, so the kernel
        operates on a transposed *view*; its in-place writes therefore land in
        ``tokens`` itself.
        NOTE(review): presumably ``tokens`` arrives as (B, H, N, D) and the
        kernel expects (B, N, H, D) — confirm with callers.

        Args:
            tokens: tensor to be modified in place.
            positions: tensor forwarded unchanged to the kernel.

        Returns:
            ``tokens`` — the very object passed in, now modified.
        """
        kernel_view = tokens.transpose(1, 2)
        # apply() returns that same dirty view; it is not needed because the
        # result is already visible through ``tokens``.
        cuRoPE2D_func.apply(kernel_view, positions, self.base, self.F0)
        return tokens