Calculation of _mscale during YaRN RoPE scaling
#4 · by sszymczyk · opened
I noticed that you calculate the cached YaRN RoPE sin and cos values like this:
```python
_mscale = float(
    yarn_get_mscale(self.scaling_factor, self.mscale)
    / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
    "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
)
self.register_buffer(
    "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
)
```
But in config.json, `self.mscale` (0.707) is equal to `self.mscale_all_dim` (also 0.707), so `yarn_get_mscale(self.scaling_factor, self.mscale)` will be equal to `yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)`, and therefore `_mscale` will simply be 1.0. Is this intentional?
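
For concreteness, here is a minimal sketch of the cancellation described above. It assumes `yarn_get_mscale` follows the usual YaRN definition (`0.1 * mscale * ln(scale) + 1.0` for `scale > 1`), and the scaling factor value is only illustrative:

```python
import math

# Assumed: the usual YaRN magnitude-scaling formula.
def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

scaling_factor = 40.0    # illustrative rope_scaling["factor"]; the exact value does not matter here
mscale = 0.707           # rope_scaling["mscale"] from config.json
mscale_all_dim = 0.707   # rope_scaling["mscale_all_dim"] from config.json

_mscale = float(
    yarn_get_mscale(scaling_factor, mscale)
    / yarn_get_mscale(scaling_factor, mscale_all_dim)
)
print(_mscale)  # 1.0 -- numerator and denominator are identical whenever mscale == mscale_all_dim
```

So with these config values the multiplication by `_mscale` in the cached sin/cos buffers is a no-op, which is what prompted the question.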
If anyone is interested, I think I finally figured it out: https://github.com/ggerganov/llama.cpp/discussions/7416