Spaces:
Sleeping
Sleeping
johnsmith253325
commited on
Commit
•
6e4855e
1
Parent(s):
c6d16d4
修正大小写和路径问题
Browse files- modules/models/StableLM.py +1 -1
- modules/models/base_model.py +1 -1
- modules/models/models.py +1 -1
modules/models/StableLM.py
CHANGED
@@ -4,7 +4,7 @@ import time
|
|
4 |
import numpy as np
|
5 |
from torch.nn import functional as F
|
6 |
import os
|
7 |
-
from base_model import BaseLLMModel
|
8 |
|
9 |
class StopOnTokens(StoppingCriteria):
|
10 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
4 |
import numpy as np
|
5 |
from torch.nn import functional as F
|
6 |
import os
|
7 |
+
from .base_model import BaseLLMModel
|
8 |
|
9 |
class StopOnTokens(StoppingCriteria):
|
10 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
modules/models/base_model.py
CHANGED
@@ -44,7 +44,7 @@ class ModelType(Enum):
|
|
44 |
model_type = ModelType.LLaMA
|
45 |
elif "xmchat" in model_name_lower:
|
46 |
model_type = ModelType.XMChat
|
47 |
-
elif "
|
48 |
model_type = ModelType.StableLM
|
49 |
else:
|
50 |
model_type = ModelType.Unknown
|
|
|
44 |
model_type = ModelType.LLaMA
|
45 |
elif "xmchat" in model_name_lower:
|
46 |
model_type = ModelType.XMChat
|
47 |
+
elif "stablelm" in model_name_lower:
|
48 |
model_type = ModelType.StableLM
|
49 |
else:
|
50 |
model_type = ModelType.Unknown
|
modules/models/models.py
CHANGED
@@ -578,7 +578,7 @@ def get_model(
|
|
578 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
579 |
model = XMChat(api_key=access_key)
|
580 |
elif model_type == ModelType.StableLM:
|
581 |
-
from StableLM import StableLM_Client
|
582 |
model = StableLM_Client(model_name)
|
583 |
elif model_type == ModelType.Unknown:
|
584 |
raise ValueError(f"未知模型: {model_name}")
|
|
|
578 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
579 |
model = XMChat(api_key=access_key)
|
580 |
elif model_type == ModelType.StableLM:
|
581 |
+
from .StableLM import StableLM_Client
|
582 |
model = StableLM_Client(model_name)
|
583 |
elif model_type == ModelType.Unknown:
|
584 |
raise ValueError(f"未知模型: {model_name}")
|