File size: 5,094 Bytes
52b380b
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddeacf
 
52b380b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
---
license: mit
language:
- ru
- en
---

# GigaChat-20B-A3B-base

Большая языковая модель, основанна на MoE архитектуре, обучена специально под русский язык **с нуля**.
Всего у модели 20 миллиардов параметров, но во время инференса задействовано только 3 миллиарда. Контекст модели =131k токенов.

Больше подробностей в [хабр статье](https://habr.com/en/companies/sberdevices/articles/865996/).

Upd. Перезалили веса в `.safetensors`

## Архитектура модели

GigaChat-20B-A3B состоит из следующих деталей:

- Fine-grained Experts + Shared Experts
- Grouped Query Attention
- Rotary Position Embeddings
- RMSNorm
- SwiGLU в MLP

Важно то, что в реализации MoE некоторые эксперты вызываются в зависимости от контекста, а другие используются всегда.

## Бенчмарки

Общие английские метрики. Для замера использовался популярный открытый репозиторий [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

| Bench                         | T-lite-0.1<br>(llama 3.1 8B based)| Llama-3.1-8B | GigaChat-20B-A3B-base | Gemma-9B  |
| ----------------------------- | ---------- | ------------ | --------------------- | --------- |
| MMLU (5-shot)                 | 62.56      | 65.21        | 63.02                 | 70.6      |
| MMLU-pro (5-shot)             | 32.19      | 35.7         | 31.41                 | 42.85     |
| MMLU-ru (5-shot)              | 55.51      | 54.1         | 58.38                 | 62.57     |
| BBH (3-shot)                  | 62.36      | 62.79        | 53.54                 | 70.48     |
| ARC-C (25-shot)               | 58.19      | 54.69        | 61.69                 | 68.34     |
| TruthfulQA (0-shot) (rougeL)  | 46.51      | 34.52        | 31.82                 | 41.49     |
| Winogrande (5-shot)           | 78.45      | 77.43        | 75.85                 | 79.4      |
| Hellaswag (10-shot)           | 82.21      | 81.85        | 81.91                 | 82.5      |
| GPQA (5-shot)                 | 0.25       | 23.44        | 25.22                 | 30.36     |
| MATH (4-shot)                 | 12.9       | 14.04        | 15.04                 | 20.06     |
| GSM8K (4-shot) (strict-match) | 67.93      | 51.4         | 59.06                 | 68.99     |
| HumanEval                     | 16.46      | 25.61        | 32.32                 | 37.2      |
| **AVG**                       | **47.96**  | **48.4**     | **49.11**             | **56.24** |


## Requirements

* ```transformers>=4.47```


## Пример использования через transformers

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model_name = "ai-sage/GigaChat-20B-A3B-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)

messages = (
    "Ниже я написал подробное доказательство теоремы о неподвижной точке:"
)
input_tensor = tokenizer(messages, return_tensors="pt").input_ids
outputs = model.generate(input_tensor.to(model.device))

result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=False)
print(result)
```

## Пример использования через vLLM

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "ai-sage/GigaChat-20B-A3B-base"
llm = LLM(model=model_name, tokenizer=model_name, trust_remote_code=True)
sampling_params = SamplingParams(
    temperature=0.3,
    max_tokens=8192,
    stop_token_ids=[tokenizer.eos_token_id]
)

messages = (
    "Ниже я написал подробное доказательство теоремы о неподвижной точке:"
)
outputs = llm.generate(messages, sampling_params=sampling_params)
generated_text = [output.outputs[0].text for output in outputs]
print(generated_text)
```

## Скорость генерации

| Model | Total params (B) | Active params (B) | Req/s | Output Token/s | Total Token/s |
|---------|-----------------|------------------|--------|----------------|----------------|
| Qwen/Qwen1.5-MoE-A2.7B-Chat | 14 | 2,7 | 0,62 | 156,43 | 291,17 |
| deepseek-ai/deepseek-moe-16b-chat | 16 | 2,8 | 0,59 | 149,53 | 285,39 |
| **GigaChat-20B-A3B** | 20 | 3,3 | 0,55 | 137,43 | 259,27 |
| Qwen/Qwen2.5-3B-Instruct | 3 | 3 | 0,54 | 135,10 | 251,44 |
| meta-llama/Meta-Llama-3-8B-Instruct | 8 | 8 | 0,35 | 83,26 | 157,32 |
| google/gemma-2-9b-it | 9 | 9 | 0,27 | 54,87 | 113,69 |