1 Star 0 Fork 0

zetingh/PaddleNLP

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
hubconf.py 3.44 KB
一键复制 编辑 原始数据 按行查看 历史
Jiaqi Liu 提交于 2021-04-27 22:47 . Fix hub bug (#319)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Packages this hubconf declares as required; paddle.hub presumably checks
# these before loading the entry points below (NOTE(review): confirm).
dependencies = [
    'paddle',
    'jieba',
    'colorlog',
    'colorama',
    'seqeval',
]
import os
from paddlenlp.transformers import BertForPretraining, BertModel, BertForSequenceClassification
from paddlenlp.transformers import BertForTokenClassification, BertForQuestionAnswering
from paddlenlp.transformers import BertTokenizer
# Accepted ``model_select`` values mapped to the BERT class to instantiate.
# Insertion order is kept, so the list of valid choices shown in error
# messages stays stable.
_BERT_MODEL_CLASSES = dict(
    bert=BertModel,
    sequence_classification=BertForSequenceClassification,
    token_classification=BertForTokenClassification,
    question_answering=BertForQuestionAnswering,
    pretraining=BertForPretraining,
)

# Pretrained model names accepted as ``model_name_or_path`` (a local
# directory path is accepted as well).
_BERT_PRETRAINED_MODELS = [
    'bert-base-uncased',
    'bert-large-uncased',
    'bert-base-multilingual-uncased',
    'bert-base-cased',
    'bert-base-chinese',
    'bert-large-cased',
    'bert-base-multilingual-cased',
    'bert-wwm-chinese',
    'bert-wwm-ext-chinese',
]
class _HubArgumentError(ValueError, AssertionError):
    """Raised when ``bert`` receives an unsupported argument value.

    Subclasses ``ValueError`` (the idiomatic type for a bad argument) and
    ``AssertionError`` (what this entry point raised when it validated via
    ``assert``), so existing ``except AssertionError`` handlers keep working.
    The leading underscore keeps it private so it is not mistaken for a hub
    entry point.
    """


def bert(model_name_or_path='bert-base-uncased',
         model_select='sequence_classification'):
    """
    Returns BERT model and tokenizer from given pretrained model name or path
    and class type of tasks, such as sequence classification.

    Args:
        model_name_or_path (str, optional): A name of or a file path to a
            pretrained model. It could be 'bert-base-uncased',
            'bert-large-uncased', 'bert-base-multilingual-uncased',
            'bert-base-cased', 'bert-base-chinese', 'bert-large-cased',
            'bert-base-multilingual-cased', 'bert-wwm-chinese' or
            'bert-wwm-ext-chinese', or an existing local directory.
            Default: 'bert-base-uncased'.
        model_select (str, optional): model class to select. It could be
            'bert', 'sequence_classification', 'token_classification',
            'question_answering' or 'pretraining'. If 'sequence_classification'
            is chosen, model class would be `BertForSequenceClassification`.
            The document of BERT model could be seen at `bert.modeling
            <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.bert.modeling.html>`_
            Default: 'sequence_classification'.

    Returns:
        tuple: Returns the pretrained bert model and bert tokenizer.

    Raises:
        ValueError: If `model_name_or_path` is neither a supported pretrained
            model name nor an existing directory, or if `model_select` is not
            a supported model class key. The raised exception also subclasses
            `AssertionError` for backward compatibility.

    Example:
        .. code-block:: python

            import paddle.hub as hub

            model, tokenizer = hub.load('PaddlePaddle/PaddleNLP:develop', model='bert', model_name_or_path='bert-base-cased')
    """
    # Explicit checks rather than ``assert``: assert statements are removed
    # when Python runs with ``-O``, which would silently skip validation.
    # ``and`` short-circuits exactly like the original ``or`` did, so the
    # filesystem is only consulted for names outside the known list.
    if (model_name_or_path not in _BERT_PRETRAINED_MODELS
            and not os.path.isdir(model_name_or_path)):
        raise _HubArgumentError(
            "Please check your model name or path. Supported model names are: {}.".format(tuple(_BERT_PRETRAINED_MODELS)))
    if model_select not in _BERT_MODEL_CLASSES:
        raise _HubArgumentError(
            "Please check `model_select`, it should be in {}.".format(tuple(_BERT_MODEL_CLASSES.keys())))

    model_class = _BERT_MODEL_CLASSES[model_select]
    model = model_class.from_pretrained(model_name_or_path)
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zetingh/PaddleNLP.git
git@gitee.com:zetingh/PaddleNLP.git
zetingh
PaddleNLP
PaddleNLP
develop

搜索帮助