
iTab-LLM

iTab-LLM is a Llama-2 7B model further pretrained on a large corpus of tables, dedicated to predictive tasks over tabular data. For details of our model, please refer to our paper: Unleashing the Potential of Large Language Models for Predictive Tabular Tasks in Data Science (link).

Demo Usage

Classification

from transformers import LlamaForSequenceClassification, LlamaTokenizer

model_name_or_path = "OldBirdAZ/itab-llm"
num_labels = 2  # set this to the number of classes in your task

# Load the backbone with a freshly initialized classification head
model = LlamaForSequenceClassification.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
)
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)

# Llama has no pad token by default; reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
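
A minimal usage sketch follows. The Markdown-style row serialization below is only an illustration (not necessarily the exact format used in the paper), and the classification head is randomly initialized, so outputs are only meaningful after fine-tuning.

import torch

# Hypothetical example row serialized as Markdown text
row_text = "| age | income | occupation |\n| --- | --- | --- |\n| 39 | 52000 | engineer |"
inputs = tokenizer(row_text, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = logits.argmax(dim=-1).item()
print(predicted_label)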

Regression

You can build a model analogous to LlamaForSequenceClassification that outputs a single numerical value, and fine-tune it by minimizing an MSE loss. See the sketch below.
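
As one concrete option (a sketch, not necessarily the exact setup used in the paper), Hugging Face's LlamaForSequenceClassification with num_labels=1 and problem_type="regression" trains the head as a regressor with an MSE loss:

from transformers import LlamaForSequenceClassification, LlamaTokenizer

model_name_or_path = "OldBirdAZ/itab-llm"
# With num_labels=1 and problem_type="regression", the sequence-classification
# head outputs a single value and is fine-tuned with MSE loss.
model = LlamaForSequenceClassification.from_pretrained(
    model_name_or_path,
    num_labels=1,
    problem_type="regression",
)
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id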

Zero-shot Prediction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import tensor_parallel as tp


model_name_or_path = "OldBirdAZ/itab-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
# Shard the model across the available GPUs
model = tp.tensor_parallel(model, sharded=True)


prompt_str = "YOUR-PROMPT"
# fillin_missing_val_prompt_str = "### Instruction: Please fill in the missing value(s) in the table in Markdown format. The missing values are marked with placeholders: <missing_value_0>, <missing_value_1>, <missing_value_2>, ... The description of this table is: Historical cryptocurrency prices for the top 50 coins, including Open, High, Low, Volume, and Change % for each date.\n\n### Input:\n| low | high | sno | open | vol. | change % | date | price |\n| -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |\n| 59.81 | 63.035 | 768.0 | 61.703 | 4320000.0 | -0.78 | 2018-09-30 | 61.224 |\n| 59.472 | 62.137 | 769.0 | 61.225 | 4480000.0 | -1.34 | 2018-10-01 | 60.401 |\n| 59.231 | 61.835 | 770.0 | 60.392 | 4430000.0 | -1.32 | 2018-10-02 | 59.606 |\n| 56.745 | 59.704 | 771.0 | 59.606 | 4780000.0 | -3.46 | 2018-10-03 | 57.541 |\n| 57.457 | 60.079 | 772.0 | 57.562 | 3260000.0 | 1.51 | 2018-10-04 | 58.411 |\n| 57.672 | 59.82 | 773.0 | 58.41 | 4630000.0 | 1.01 | 2018-10-05 | 59.001 |\n| 56.692 | 59.169 | 774.0 | 59.065 | 4730000.0 | -1.78 | 2018-10-06 | 57.951 |\n| 56.986 | 58.533 | 775.0 | 57.951 | 1500000.0 | 0.56 | 2018-10-07 | 58.273 |\n| 57.693 | 60.163 | 776.0 | 58.274 | 2260000.0 | 2.37 | 2018-10-08 | 59.655 |\n| 58.523 | 59.887 | 777.0 | 59.655 | 2230000.0 | -1.15 | 2018-10-09 | 58.968 |\n| 50.968 | 54.149 | 780.0 | 51.263 | 2170000.0 | 4.99 | 2018-10-12 | 53.806 |\n| 52.233 | 61.175 | 783.0 | 52.545 | 3210000.0 | 6.48 | 2018-10-15 | 55.951 |\n| 54.619 | 56.809 | 784.0 | 55.95 | 1860000.0 | -0.97 | 2018-10-16 | 55.408 |\n| 54.339 | 55.571 | 785.0 | 55.416 | 1960000.0 | 0.14 | 2018-10-17 | 55.484 |\n| 52.965 | 54.47 | 787.0 | 53.431 | 2190000.0 | 0.4 | 2018-10-19 | 53.638 |\n| 53.476 | 54.751 | 789.0 | 54.233 | 2400000.0 | -0.94 | 2018-10-21 | 53.721 |\n| 52.78 | 54.176 | 790.0 | 53.686 | 2240000.0 | -1.15 | 2018-10-22 | 53.105 |\n| 50.648 | 54.694 | 791.0 | 53.122 | 2470000.0 | 0.41 | 2018-10-23 | 53.32 |\n| 52.894 | 53.933 | 792.0 | 53.325 | 2910000.0 | -0.44 | 2018-10-24 | 53.088 |\n| 52.677 | 53.293 | 793.0 | 53.094 | 2100000.0 | -0.16 | 2018-10-25 | 53.003 |\n| 52.382 | 53.447 | 794.0 | 53.003 | 2570000.0 | -0.75 | 2018-10-26 | 52.607 |\n| 51.994 | 53.202 | 795.0 | 52.608 | 2190000.0 | -0.46 | 2018-10-27 | 52.364 |\n| 48.161 | 52.501 | 797.0 | 51.958 | 2710000.0 | -5.14 | 2018-10-29 | 49.314 |\n| 48.915 | 50.071 | 798.0 | 49.307 | 1520000.0 | 0.0 | 2018-10-30 | 49.314 |\n| 48.209 | 50.652 | 799.0 | 49.317 | 2510000.0 | 1.28 | 2018-10-31 | 49.943 |\n| 49.851 | 50.781 | 800.0 | 49.952 | 2050000.0 | 1.31 | 2018-11-01 | <missing_value_0> |\n| 48.241 | 52.113 | 801.0 | 50.595 | 2280000.0 | 2.21 | 2018-11-02 | 51.716 |\n| 48.437 | 56.253 | 803.0 | 51.115 | 2280000.0 | 6.76 | 2018-11-04 | 54.573 |\n| 52.94 | 55.129 | 804.0 | 54.572 | 2020000.0 | -1.84 | 2018-11-05 | 53.57 |\n| 51.256 | 56.477 | 805.0 | 53.567 | 1690000.0 | 5.28 | 2018-11-06 | 56.401 |\n\n### Response: "
max_dec_len = 64  # maximum number of new tokens to generate

prompt = tokenizer(prompt_str, return_tensors="pt")
input_ids = prompt["input_ids"].to(model.device)
with torch.no_grad():
    response_result = model.generate(
        input_ids,
        max_new_tokens=max_dec_len,
        output_scores=True,
        return_dict_in_generate=True,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
# Decode only the newly generated tokens (everything after the prompt)
response = tokenizer.decode(response_result["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True).strip()
result = {"generated_text": response.split("\n")[0].strip()}
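
For a numeric task like the fill-in-the-missing-value prompt above, you may want to convert the generated text back to a number. A small, hedged post-processing sketch (assuming the model emits a single numeric value on the first line; the exact output format may differ):

import re

# Hypothetical post-processing: pull the first numeric value out of the
# generated line (e.g. the prediction for <missing_value_0> above).
match = re.search(r"-?\d+(?:\.\d+)?", result["generated_text"])
predicted_value = float(match.group()) if match else None
print(predicted_value)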

Ethical Considerations and Limitations

This model is a further-pretrained version of Llama-2 7B on tables. Because the pretraining data was mainly collected from Kaggle, you must rigorously follow Kaggle's terms and licensing agreements and adhere to legal and ethical standards if you would like to use this model. You must also adhere to the corresponding license and requirements of Llama-2 7B. Testing conducted to date has been in English and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, iTab-LLM's potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased, or otherwise objectionable responses to user prompts. Therefore, before deploying this model or any application based on it, developers should perform safety testing and tuning tailored to their specific use case.
