
Update modeling_hf_nomic_bert.py

#11 by zpn - opened
Files changed (1)
  1. modeling_hf_nomic_bert.py +244 -1
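
This PR adds three downstream task heads to the remote-code modeling file: NomicBertForMultipleChoice, NomicBertForTokenClassification, and NomicBertForQuestionAnswering, together with the transformers.modeling_outputs imports they rely on. A hedged usage sketch follows the diff below.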
modeling_hf_nomic_bert.py CHANGED
@@ -22,12 +22,22 @@ from einops import rearrange, repeat
  from safetensors.torch import load_file as safe_load_file
  from torch.nn.modules.utils import _pair
  from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel
- from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.models.bert.modeling_bert import (
      BaseModelOutputWithPoolingAndCrossAttentions,
      MaskedLMOutput,
      SequenceClassifierOutput,
  )
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPast,
+     BaseModelOutputWithPooling,
+     MaskedLMOutput,
+     MultipleChoiceModelOutput,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutput,
+     ModelOutput,
+     TokenClassifierOutput,
+ )
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
@@ -1853,6 +1863,239 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
              attentions=outputs.attentions,
          )

+ class NomicBertForMultipleChoice(NomicBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.new = NomicBertModel(config, add_pooling_layer=True)
+         classifier_dropout = (
+             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         unpad_inputs: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.new(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             unpad_inputs=unpad_inputs,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ class NomicBertForTokenClassification(NomicBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = NomicBertModel(config, add_pooling_layer=False)
+         classifier_dropout = (
+             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         sequence_output = self.dropout(sequence_output)
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ class NomicBertForQuestionAnswering(NomicBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = NomicBertModel(config, add_pooling_layer=False)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         start_positions: Optional[torch.Tensor] = None,
+         end_positions: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+         r"""
+         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()
+         end_logits = end_logits.squeeze(-1).contiguous()
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, splitting adds a dimension
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+             # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (start_logits, end_logits) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return QuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )

  def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
      return GPT2Config(
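
For reviewers, a minimal sketch of how one of the new heads could be exercised once this change is merged. It assumes the file ships in a Hub repository such as nomic-ai/nomic-bert-2048 (repository id assumed, not part of this diff) and that config.json exposes the head through an auto_map entry for AutoModelForTokenClassification; if no such mapping is added, the class would have to be imported from the downloaded module directly.

import torch
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

# Assumed repository id; the diff itself does not name a repo or add an auto_map entry.
repo_id = "nomic-ai/nomic-bert-2048"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.num_labels = 9  # e.g. a CoNLL-style NER tag set; the classification head starts randomly initialized

model = AutoModelForTokenClassification.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,  # the head class lives in modeling_hf_nomic_bert.py
)

inputs = tokenizer("Nomic AI is based in New York", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)  # (batch_size, sequence_length, num_labels)

The same pattern would apply to the multiple-choice and question-answering heads via AutoModelForMultipleChoice and AutoModelForQuestionAnswering, again assuming matching auto_map entries.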