# sagemaker-launcher / trainer.py
# Author: philschmid (HF staff) — "online trainer", commit 25f0c96
from sagemaker.huggingface import HuggingFace
import logging
import sys
from contextlib import contextmanager
from io import StringIO
from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME
from threading import current_thread
import streamlit as st
import sys
import sagemaker
import boto3
@contextmanager
def st_redirect(src, dst):
placeholder = st.empty()
output_func = getattr(placeholder, dst)
with StringIO() as buffer:
old_write = src.write
def new_write(b):
if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
buffer.write(b)
output_func(buffer.getvalue())
else:
old_write(b)
try:
src.write = new_write
yield
finally:
src.write = old_write
@contextmanager
def st_stdout(dst):
with st_redirect(sys.stdout, dst):
yield
@contextmanager
def st_stderr(dst):
with st_redirect(sys.stderr, dst):
yield
task2script = {
"text-classification": {
"entry_point": "run_glue.py",
"source_dir": "examples/text-classification",
},
"token-classification": {
"entry_point": "run_ner.py",
"source_dir": "examples/token-classification",
},
"question-answering": {
"entry_point": "run_qa.py",
"source_dir": "examples/question-answering",
},
"summarization": {
"entry_point": "run_summarization.py",
"source_dir": "examples/seq2seq",
},
"translation": {
"entry_point": "run_translation.py",
"source_dir": "examples/seq2seq",
},
"causal-language-modeling": {
"entry_point": "run_clm.py",
"source_dir": "examples/language-modeling",
},
"masked-language-modeling": {
"entry_point": "run_mlm.py",
"source_dir": "examples/language-modeling",
},
}
def train_estimtator(parameter, config):
with st_stdout("code"):
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.getLevelName("INFO"),
handlers=[logging.StreamHandler(sys.stdout)],
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger.info = print
# git configuration to download our fine-tuning script
git_config = {"repo": "https://github.com/huggingface/transformers.git", "branch": "v4.4.2"}
# creating fine-tuning script
entry_point = task2script[parameter["task"]]["entry_point"]
source_dir = task2script[parameter["task"]]["source_dir"]
# create train file
# iam configuration
session = boto3.session.Session(
aws_access_key_id=config["aws_access_key_id"],
aws_secret_access_key=config["aws_secret_accesskey"],
region_name=config["region"],
)
sess = sagemaker.Session(boto_session=session)
iam = session.client(
"iam", aws_access_key_id=config["aws_access_key_id"], aws_secret_access_key=config["aws_secret_accesskey"]
)
role = iam.get_role(RoleName=config["aws_sagemaker_role"])["Role"]["Arn"]
logger.info(f"role: {role}")
instance_type = config["instance_type"].split("|")[1].split("|")[0].strip()
logger.info(f"instance_type: {instance_type}")
hyperparameters = {
"output_dir": "/opt/ml/model",
"do_train": True,
"do_eval": True,
"do_predict": True,
**parameter,
}
del hyperparameters["task"]
# create estimator
huggingface_estimator = HuggingFace(
entry_point=entry_point,
source_dir=source_dir,
git_config=git_config,
base_job_name=config["job_name"],
instance_type=instance_type,
sagemaker_session=sess,
instance_count=config["instance_count"],
role=role,
transformers_version="4.4",
pytorch_version="1.6",
py_version="py36",
hyperparameters=hyperparameters,
)
# train
huggingface_estimator.fit()