sippycoder committed
Commit 44fdd22
1 Parent(s): 2ca4391

initial commit

Files changed (2)
  1. configuration_nucleus.py +89 -0
  2. modeling_nucleus.py +155 -0
configuration_nucleus.py ADDED
@@ -0,0 +1,89 @@
# This config is based on LLaMA.
""" Nucleus model configuration"""

from transformers import LlamaConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

NUCLEUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class NucleusConfig(LlamaConfig):
    model_type = "nucleus"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias

        # Forward the architecture hyper-parameters to LlamaConfig as well;
        # its __init__ would otherwise reset them to the LLaMA defaults.
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pretraining_tp=pretraining_tp,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
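
A minimal usage sketch for the configuration above (not part of the commit; the small hyper-parameter values and the local import path are illustrative assumptions, not the released model's settings):

# Illustrative only: build a small NucleusConfig and exercise the rope_scaling validation.
from configuration_nucleus import NucleusConfig

config = NucleusConfig(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1376,
    num_hidden_layers=4,
    num_attention_heads=8,
    rope_scaling={"type": "linear", "factor": 2.0},  # must contain exactly `type` and `factor`
)

print(config.model_type)    # "nucleus"
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}

# A factor <= 1.0 (e.g. 0.5) would raise a ValueError from _rope_scaling_validation().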
modeling_nucleus.py ADDED
@@ -0,0 +1,155 @@
# This code is based on LLaMA.
""" PyTorch Nucleus model."""
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_nucleus import NucleusConfig

from transformers import (
    LlamaPreTrainedModel,
    LlamaModel,
)


class NucleusForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: NucleusConfig):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Split the LM head into pretraining_tp slices and concatenate the per-slice logits
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
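
A minimal smoke-test sketch for the model class (not part of the commit). It assumes the two committed files are saved in a local Python package, here hypothetically named nucleus/, so that the relative import in modeling_nucleus.py resolves; the tiny hyper-parameters are placeholders, not the released model's:

# Illustrative only: register the custom classes and run a small forward pass.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical package layout: nucleus/__init__.py, nucleus/configuration_nucleus.py, nucleus/modeling_nucleus.py
from nucleus.configuration_nucleus import NucleusConfig
from nucleus.modeling_nucleus import NucleusForCausalLM

# Optional: let the Auto* factories resolve model_type "nucleus" to these classes.
AutoConfig.register("nucleus", NucleusConfig)
AutoModelForCausalLM.register(NucleusConfig, NucleusForCausalLM)

# Tiny placeholder config for a quick forward-pass check.
config = NucleusConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=128,
)
model = NucleusForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss.item(), outputs.logits.shape)  # scalar loss, torch.Size([1, 16, 1000])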