johnsmith253325 commited on
Commit
6e4855e
1 Parent(s): c6d16d4

修正大小写和路径问题

Browse files
modules/models/StableLM.py CHANGED
@@ -4,7 +4,7 @@ import time
4
  import numpy as np
5
  from torch.nn import functional as F
6
  import os
7
- from base_model import BaseLLMModel
8
 
9
  class StopOnTokens(StoppingCriteria):
10
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
4
  import numpy as np
5
  from torch.nn import functional as F
6
  import os
7
+ from .base_model import BaseLLMModel
8
 
9
  class StopOnTokens(StoppingCriteria):
10
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
modules/models/base_model.py CHANGED
@@ -44,7 +44,7 @@ class ModelType(Enum):
44
  model_type = ModelType.LLaMA
45
  elif "xmchat" in model_name_lower:
46
  model_type = ModelType.XMChat
47
- elif "StableLM" in model_name_lower:
48
  model_type = ModelType.StableLM
49
  else:
50
  model_type = ModelType.Unknown
 
44
  model_type = ModelType.LLaMA
45
  elif "xmchat" in model_name_lower:
46
  model_type = ModelType.XMChat
47
+ elif "stablelm" in model_name_lower:
48
  model_type = ModelType.StableLM
49
  else:
50
  model_type = ModelType.Unknown
modules/models/models.py CHANGED
@@ -578,7 +578,7 @@ def get_model(
578
  access_key = os.environ.get("XMCHAT_API_KEY")
579
  model = XMChat(api_key=access_key)
580
  elif model_type == ModelType.StableLM:
581
- from StableLM import StableLM_Client
582
  model = StableLM_Client(model_name)
583
  elif model_type == ModelType.Unknown:
584
  raise ValueError(f"未知模型: {model_name}")
 
578
  access_key = os.environ.get("XMCHAT_API_KEY")
579
  model = XMChat(api_key=access_key)
580
  elif model_type == ModelType.StableLM:
581
+ from .StableLM import StableLM_Client
582
  model = StableLM_Client(model_name)
583
  elif model_type == ModelType.Unknown:
584
  raise ValueError(f"未知模型: {model_name}")