|
---
base_model: unsloth/llama-3-8b-bnb-4bit
language:
- en
license: apache-2.0
tags:
- text-generation-inference
- transformers
- unsloth
- llama
- trl
datasets:
- Studeni/robot-instructions
pipeline_tag: text-generation
---
|
|
|
# Llama 3 8B Robot Instruction Model (4-bit) |
|
|
|
## Model description |
|
|
|
This model is a fine-tuned version of Llama 3 8B, trained with Unsloth and quantized to 4-bit with bitsandbytes.
It converts casual user input into a list of function calls for controlling industrial robots,
lowering the barrier for people without programming skills to control robots through simple text instructions.
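
For illustration, an interaction might look like the sketch below. The argument names in `kwargs` are assumptions for illustration, not the dataset's exact schema; the output format (a JSON list of `function`/`kwargs` objects) matches what the usage examples parse.

```python
# Illustrative only: the kwargs names below are assumptions, not the exact schema.
user_input = "Move the robot 50 mm up, then read the joint values."
expected_output = [
    {"function": "move_tcp", "kwargs": {"x": 0, "y": 0, "z": 50}},
    {"function": "get_joint_values", "kwargs": {}},
]
```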
|
|
|
## Model Details

- **Model ID:** Studeni/llama-3-8b-bnb-4bit-robot-instruct
- **Architecture:** Llama 3 8B
- **Quantization:** 4-bit (bitsandbytes)
- **Frameworks:** Transformers, PEFT, Unsloth
|
|
|
## Usage |
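
The examples below assume a CUDA-capable GPU and the `unsloth`, `transformers`, `peft`, and `datasets` packages, along with the dependencies required for 4-bit loading (bitsandbytes, accelerate).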
|
|
|
### Using the Unsloth Library
|
|
|
```python
import json

from datasets import load_dataset
from unsloth import FastLanguageModel

# Load a test example from the dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Alpaca-style prompt used for fine-tuning
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.

### Input:
{}

### Response:
{}
"""

# Model parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
max_seq_length = 2048
dtype = None  # auto-detect; use float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_id,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference mode

# Tokenize the input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation
outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the function calls and parse them as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
```
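
Note that `batch_decode` returns the full prompt together with the completion, which is why the snippet splits on `### Response:` before parsing the JSON.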
|
|
|
### Using Transformers and PEFT
|
```python
import json

from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load a test example from the dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Alpaca-style prompt used for fine-tuning
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.

### Input:
{}

### Response:
{}
"""

# Model parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
load_in_4bit = True

# Load the model and tokenizer
model = AutoPeftModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=lora_id,
    load_in_4bit=load_in_4bit,
    device_map="auto",  # place the quantized weights on the available GPU
)
tokenizer = AutoTokenizer.from_pretrained(lora_id)

# Tokenize the input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation
outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the function calls and parse them as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
```
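
Generation is not guaranteed to produce valid JSON, so in practice you may want to fail with a clear error instead of letting `json.loads` raise mid-pipeline. A minimal sketch of a defensive parser (the helper name is ours, not part of any library):

```python
import json

def parse_function_calls(decoded_text: str) -> list[dict]:
    """Extract and validate the JSON list that follows '### Response:'."""
    payload = decoded_text.split("### Response:")[-1].strip()
    try:
        calls = json.loads(payload)
    except json.JSONDecodeError as err:
        raise ValueError(f"Model did not return valid JSON: {payload!r}") from err
    if not isinstance(calls, list):
        raise ValueError(f"Expected a JSON list, got {type(calls).__name__}")
    return calls
```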
|
|
|
## Limitations and Future Work 🚨 |
|
This model is currently a work in progress and supports only three basic functions: `move_tcp`, `move_joint`, and `get_joint_values`. |
|
Future iterations will include a more comprehensive dataset with more complex commands and capabilities, better human-labeled data, and improved performance metrics. |
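
To connect the parsed output to an actual robot, a small dispatch table is enough. A minimal sketch, assuming a hypothetical robot API; replace the stub functions with your controller's real bindings:

```python
# Hypothetical stubs: swap these for your robot controller's real bindings.
def move_tcp(**kwargs):
    print(f"Moving TCP: {kwargs}")

def move_joint(**kwargs):
    print(f"Moving joints: {kwargs}")

def get_joint_values(**kwargs):
    print("Reading joint values")

DISPATCH = {
    "move_tcp": move_tcp,
    "move_joint": move_joint,
    "get_joint_values": get_joint_values,
}

# In practice, iterate over the `function_call` list parsed in the usage examples.
function_calls = [{"function": "get_joint_values", "kwargs": {}}]
for call in function_calls:
    DISPATCH[call["function"]](**call["kwargs"])
```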
|
|
|
## Contributions and Collaborations 🤝 |
|
|
|
We welcome contributions and collaborations to help improve and expand the capabilities of this model. Whether you are interested in adding more complex functions, improving the dataset, or enhancing the model's performance, your input is valuable. |
|
Feel free to connect with me on [LinkedIn](https://www.linkedin.com/in/milutin-studen/).
|
|
|
--- |
|
This Llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
|
|
|
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth) |