1 Star 0 Fork 0

yzjia/handson-ml

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
data_cache.py 4.18 KB
一键复制 编辑 原始数据 按行查看 历史
yzjia 提交于 2020-06-15 17:08 . change the behavior of fetch_openml.
import os
import pickle
import sqlalchemy as sa
from sklearn.datasets import fetch_openml as sk_fetch_openml
from sqlalchemy import Column, String, BigInteger, create_engine, and_, or_
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
# Directory that holds the pickled datasets and the SQLite index database.
# abspath keeps these paths valid even if the working directory changes
# after this module is imported (e.g. a notebook calling os.chdir).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ONLINE_DATA_DIR = os.path.join(BASE_DIR, 'online_data')
# exist_ok avoids the check-then-create race of isdir()+mkdir() when several
# processes import this module at the same time; a non-directory at this path
# still raises FileExistsError, as before.
os.makedirs(ONLINE_DATA_DIR, exist_ok=True)

# SQLite database indexing which datasets have been cached.
engine = create_engine(
    f'sqlite:///{os.path.join(ONLINE_DATA_DIR, "data_info.db")}')
# NOTE(review): binding metadata via declarative_base(bind=...) is deprecated
# in SQLAlchemy 1.4 and removed in 2.0; kept for compatibility with code that
# relies on Base.metadata being bound.
Base = declarative_base(bind=engine)
class OnlineDataInfo(Base):
    """Index row describing one OpenML dataset pickled under ONLINE_DATA_DIR.

    Rows are keyed by the OpenML dataset id; ``file_name`` is the stem
    (without the ``.pkl`` suffix) passed to ``save_online_data``.
    """

    __tablename__ = 'online_data_profile'

    # OpenML dataset id, stored as text; primary key of the index.
    data_id = Column(String(5), primary_key=True)
    # OpenML dataset name, e.g. 'mnist_784'.
    data_name = Column(String(128))
    # OpenML dataset version number.
    version = Column(BigInteger)
    # Stem of the cached pickle file inside ONLINE_DATA_DIR.
    file_name = Column(String(200))

    def __repr__(self):
        template = '<id: {} name: {} version: {}>'
        return template.format(self.data_id, self.data_name, self.version)
# Create the index table on first use. Passing the engine explicitly avoids
# relying on implicitly bound metadata (deprecated in SQLAlchemy 1.4,
# removed in 2.0) and works on older versions too.
Base.metadata.create_all(bind=engine)
# Factory for sessions that talk to the cache index database.
Session = sessionmaker(bind=engine)
def save_online_data(data_obj, data_name):
    """Pickle ``data_obj`` to ``<ONLINE_DATA_DIR>/<data_name>.pkl``.

    The pickle is first written to a sibling ``.tmp`` file and then moved
    into place with ``os.replace``. An interrupted or failed dump (e.g.
    Ctrl-C while writing a large dataset) therefore never leaves a truncated
    ``.pkl`` behind and never clobbers a previously cached copy.

    Args:
        data_obj: any picklable object.
        data_name: file stem (without ``.pkl``) inside ONLINE_DATA_DIR.

    Returns:
        None.
    """
    target_path = os.path.join(ONLINE_DATA_DIR, data_name + '.pkl')
    tmp_path = target_path + '.tmp'
    try:
        with open(tmp_path, 'wb') as file_obj:
            pickle.dump(data_obj, file_obj)
        # Atomic replacement when both paths are on the same filesystem.
        os.replace(tmp_path, target_path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error
        # is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_online_data(data_name):
    """Unpickle and return the object cached as ``<data_name>.pkl``.

    The file is looked up in ONLINE_DATA_DIR. A missing file surfaces as
    FileNotFoundError, which fetch_openml treats as a cache miss.

    NOTE(review): pickle.load can execute arbitrary code from the file; only
    point ONLINE_DATA_DIR at a directory you trust.
    """
    pickle_path = os.path.join(ONLINE_DATA_DIR, data_name + '.pkl')
    with open(pickle_path, 'rb') as handle:
        cached = pickle.load(handle)
    return cached
def fetch_openml_from_cache(session: Session, name=None, version='active', data_id=None):
    """Load a previously downloaded OpenML dataset from the local cache.

    Lookup rules:
      * a truthy ``data_id`` is matched against the index directly;
      * otherwise ``name`` is matched together with ``version`` when
        ``version`` is a number or a string of decimal digits;
      * any other ``version`` value (such as the default ``'active'``)
        selects the highest cached version of ``name``.

    Args:
        session: open SQLAlchemy session for the cache index database.
        name: OpenML dataset name.
        version: version number, digit string, or a non-numeric marker.
        data_id: OpenML dataset id; ints and strings are both accepted.

    Returns:
        The object previously stored by ``save_online_data``.

    Raises:
        FileNotFoundError: no index row matches, or the row's pickle file is
            missing.
    """
    query = session.query(OnlineDataInfo)
    if data_id:
        # The id column is a String; normalise ints such as 554 to '554'.
        data_info = query.filter_by(data_id=str(data_id)).first()
    elif isinstance(version, (int, float)) or (
            isinstance(version, str) and version.isdecimal()):
        # isdecimal (not isnumeric) guarantees int() can parse the string;
        # isnumeric also accepts characters like '½' that int() rejects.
        data_info = query.filter_by(
            data_name=name, version=int(version)).first()
    else:
        # 'active' or any other marker: use the newest cached version.
        data_info = (query.filter_by(data_name=name)
                     .order_by(OnlineDataInfo.version.desc())
                     .first())

    if data_info is None:
        raise FileNotFoundError(
            f'not downloaded yet: {{name: {name}, version: {version}, data_id: {data_id}}}')
    return load_online_data(data_info.file_name)
def fetch_openml_online(session: Session, name=None, version='active', data_id=None, data_home=None,
                        target_column='default-target', cache=True, return_X_y=False):
    """Download a dataset with sklearn's ``fetch_openml`` and record it in the cache.

    The full Bunch is always requested from sklearn (``return_X_y=False``).
    Its ``details`` are needed to build the cache entry, and the Bunch is what
    gets pickled. ``return_X_y`` is applied only to the returned value.
    Any existing index row with the same OpenML id is replaced.

    Args:
        session: SQLAlchemy session for the cache index; committed on success.
        name, version, data_id, data_home, target_column, cache:
            forwarded unchanged to ``sklearn.datasets.fetch_openml``.
        return_X_y: when true, return ``(data.data, data.target)`` instead
            of the Bunch, matching sklearn's own return value.

    Returns:
        The downloaded Bunch, or an ``(X, y)`` tuple when ``return_X_y`` is true.
    """
    # NOTE(review): the cache is keyed only by OpenML id / name / version, so a
    # later cache hit does not take the target_column used here into account.
    data = sk_fetch_openml(name=name, version=version, data_id=data_id, data_home=data_home,
                           target_column=target_column, cache=cache, return_X_y=False)
    details = data.details
    data_id = details.get('id')
    data_name = details.get('name')
    version = details.get('version')
    file_name = f'{data_name}-{version}'

    old_data_info = session.query(OnlineDataInfo).filter_by(data_id=data_id).first()
    if old_data_info:
        session.delete(old_data_info)
    save_online_data(data, file_name)
    session.add(OnlineDataInfo(
        data_id=data_id,
        data_name=data_name,
        version=int(version),
        file_name=file_name,
    ))
    session.commit()

    if return_X_y:
        return data.data, data.target
    return data
def fetch_openml(name=None, version='active', data_id=None, data_home=None,
                 target_column='default-target', cache=True, return_X_y=False,
                 from_cache=True):
    """Wrapper around sklearn's ``fetch_openml`` backed by a local pickle cache.

    With ``from_cache`` true, the local cache is consulted first. On a miss,
    the user is asked on stdin whether to download. An answer starting with
    'n'/'N' returns ``None``; any other answer downloads. With ``from_cache``
    false, the dataset is always downloaded and re-cached.

    A cache hit honours only ``name``, ``version``, ``data_id`` and
    ``return_X_y``; ``data_home``, ``target_column`` and ``cache`` affect
    downloads only.

    Returns:
        The dataset Bunch, an ``(X, y)`` tuple when ``return_X_y`` is true,
        or ``None`` if the user declines the download.
    """
    session = Session()
    try:
        if from_cache:
            try:
                data = fetch_openml_from_cache(
                    session, name=name, version=version, data_id=data_id)
            except FileNotFoundError:
                download_online = input('data not found, download online now?[Y|N]')
                if download_online[:1].lower() == 'n':
                    return None
            else:
                # Assumes the cached object is the Bunch pickled by
                # fetch_openml_online; honour return_X_y the way sklearn does.
                if return_X_y:
                    return data.data, data.target
                return data
        return fetch_openml_online(
            session, name=name, version=version, data_id=data_id,
            data_home=data_home, target_column=target_column, cache=cache,
            return_X_y=return_X_y)
    finally:
        session.close()
if __name__ == '__main__':
    # Show one row of the cache index (None when the cache is empty),
    # then fetch MNIST through the cache.
    session = Session()
    try:
        first_record = session.query(OnlineDataInfo).first()
        print(first_record)
    finally:
        session.close()
    data = fetch_openml(name='mnist_784')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yzjia0827/handson-ml.git
git@gitee.com:yzjia0827/handson-ml.git
yzjia0827
handson-ml
handson-ml
master

搜索帮助