"""Ask a medical question to a 4-bit quantized Medical-Llama3-8B chat model.

Loads the model and tokenizer at import time, installs a ChatML-style chat
template on the tokenizer, and exposes :func:`askme`, which turns a
question into a prompt, generates a completion and returns only the
assistant's answer.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Model name and configuration
model_name = "ruslanmv/Medical-Llama3-8B"
device_map = "auto"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the model and tokenizer.
# NOTE(review): trust_remote_code=True lets code shipped in the model
# repository run locally -- confirm the repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Set pad_token_id to eos_token_id if None
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Define the chat template.
# apply_chat_template() renders this string with Jinja, so it must iterate
# over `messages` itself; str.format-style placeholders such as "{system}"
# would be emitted literally and the conversation would never reach the
# prompt. Each turn is laid out as: start marker + role, content, end marker,
# one per line (assumed layout -- TODO confirm against the model's training
# format).
chat_template = (
    r"{% for message in messages %}"
    r"{{ '<|im_start|>' + message['role'] + '\n'"
    r" + message['content'] + '\n<|im_end|>\n' }}"
    r"{% endfor %}"
    r"{% if add_generation_prompt %}"
    r"{{ '<|im_start|>assistant\n' }}"
    r"{% endif %}"
)
tokenizer.chat_template = chat_template


# Function to generate a response
def askme(question: str, *, max_new_tokens: int = 100) -> str:
    """Return the model's answer to a medical *question*.

    Args:
        question: The user's question, inserted as the ``user`` turn.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated assistant text with the prompt removed and
        surrounding whitespace stripped.
    """
    sys_message = (
        " You are an AI Medical Assistant trained on a vast dataset of"
        " health information. Please be thorough and provide an informative"
        " answer. If you don't know the answer to a specific medical"
        " inquiry, advise seeking professional help. "
    )

    # Structure messages for the chat
    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]

    # Apply the chat template
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Follow the model's own placement (device_map="auto") instead of
    # hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns prompt + completion; keep only the new tokens so the
    # answer does not depend on the chat markers surviving decoding.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    response_text = tokenizer.decode(completion_ids, skip_special_tokens=True)

    # If the end-of-turn marker is not a special token it is still present
    # in the decoded text; cut the answer there.
    answer, _, _ = response_text.partition("<|im_end|>")
    return answer.strip()


# Example usage
question = (
    " I'm a 35-year-old male and for the past few months, I've been"
    " experiencing fatigue, increased sensitivity to cold, and dry, itchy"
    " skin. \n"
    "Could these symptoms be related to hypothyroidism? If so, what steps"
    " should I take to get a proper diagnosis and discuss treatment"
    " options? "
)
print(askme(question))