From 77feb2212b25b771a329de6401ea14e4acff4f0f Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 15 Aug 2024 20:25:03 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=A0=E9=99=A4=E3=80=81?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- oauth2_provider/app/core/applications.py | 101 +++++++++++++++--- oauth2_provider/app/serialize/applications.py | 47 ++++++++ oauth2_provider/app/views/applications.py | 62 +++++++++-- oauth2_provider/database/table.py | 19 +++- oauth2_provider/urls.py | 2 +- 5 files changed, 202 insertions(+), 29 deletions(-) diff --git a/oauth2_provider/app/core/applications.py b/oauth2_provider/app/core/applications.py index c935bab..6c478f3 100644 --- a/oauth2_provider/app/core/applications.py +++ b/oauth2_provider/app/core/applications.py @@ -10,6 +10,8 @@ from vulcanus.restful.resp.state import ( DATA_EXIST, DATABASE_INSERT_ERROR, DATABASE_QUERY_ERROR, + DATABASE_UPDATE_ERROR, + DATABASE_DELETE_ERROR, SUCCEED, ) @@ -22,7 +24,7 @@ class ApplicationProxy: Application related table operation """ - def create_application(self, data) -> str: + def create_application(self, data: dict) -> str: """create application. Args: data (dict): @@ -46,13 +48,18 @@ class ApplicationProxy: client_id = gen_salt(24) client_name = data.get('client_name') client = OAuth2Client(client_id=client_id, user_id=data.get('user_id')) + client.user_id = data.get('user_id') client.client_id_issued_at = int(time.time()) + client.app_name = data.get('client_name') + scopes_set = set(["profile", "email", "openid", "phone", "offline_access"]) + for scope_item in self._split_by_crlf(data.get('allowed_scope')): + scopes_set.add(scope_item) client.set_client_metadata({ "client_name": data.get('client_name'), "client_uri": data.get('client_uri'), "skip_auth": data.get('skip_auth'), "register_callback_uri": data.get('register_callback_uri'), - "scope": data.get('allowed_scope'), + "scope": list(scopes_set), "redirect_uris": self._split_by_crlf(data.get('redirect_uris')), "grant_types": data.get('allowed_grant_types'), "respnoses_types": self._split_by_crlf(data.get('allowed_responses_types')), @@ -65,7 +72,7 @@ class ApplicationProxy: try: if not self._check_client_name_not_exist(client_name): LOGGER.error(f"create application failed, application exists: {client_name}") - return DATA_EXIST + return DATA_EXIST, dict() db.session.add(client) db.session.commit() LOGGER.debug("create application succeed.") @@ -73,12 +80,12 @@ class ApplicationProxy: LOGGER.error(error) LOGGER.error("create application failed.") db.session.rollback() - return DATABASE_INSERT_ERROR - return SUCCEED + return DATABASE_INSERT_ERROR, dict() + return SUCCEED, {"client_info": client.client_info, "client_metadata": client.client_metadata} def _check_client_name_not_exist(self, client_name: str): - query_res = db.session.query(OAuth2Client).filter_by(client_name=client_name).count() - if query_res != 0: + query_res = db.session.query(OAuth2Client).filter(OAuth2Client.app_name == client_name).count() + if query_res: return False return True @@ -87,14 +94,13 @@ class ApplicationProxy: return [] return [v for v in s.splitlines() if v] - def get_all_applications(self): + def get_all_applications(self, user_id: str): """Get all applications - Returns: Tuple[str, list]: status_code, applications_info """ try: - applications = db.session.query(OAuth2Client).all() + applications = db.session.query(OAuth2Client).filter(user_id==user_id).all() applications_info = [] for application in applications: applications_info.append({ @@ -107,20 +113,87 @@ class ApplicationProxy: return DATABASE_QUERY_ERROR, [] return SUCCEED, applications_info - def get_one_application(self, client_id: str): + def get_one_application(self, client_id: str, user_id: str): """Get one application + data (dict): + { + client_name = fields.String(required=True) + client_uri = fields.String(required=True) + skip_auth = fields.String(required=True) + register_callback_uri = fields.String(required=True) + allowed_scope = fields.String(required=True) + redirect_uris = fields.String(required=True) + allowed_grant_types = fields.String(required=True) + allowed_responses_types = fields.String(required=True) + token_endpoint_auth_method = fields.String(required=True) + } Returns: - Tuple[str, dict]: status_code, application + Tuple[str, dict]: status_code, application_info """ try: - application = db.session.query(OAuth2Client).filter_by(client_id=client_id).one_or_none() + application = db.session.query(OAuth2Client).filter( + OAuth2Client.client_id==client_id, + OAuth2Client.user_id==user_id + ).one_or_none() + if not application: + LOGGER.info(f'''no application refer to this client_id {client_id}, this user id {user_id}''') + return DATABASE_QUERY_ERROR, dict() application_info = { "client_info": application.client_info, "client_metadata": application.client_metadata } except sqlalchemy.exc.SQLAlchemyError as error: LOGGER.error(error) - LOGGER.err("get one account info failed") + LOGGER.error("get one application info failed") return DATABASE_QUERY_ERROR, dict() return SUCCEED, application_info + def update_one_application(self, user_id: str, client_id: str, data: dict): + """Update one application + Returns: + Tuple[str, dict]: status_code, application_info + """ + try: + scopes_set = set(["profile", "email", "openid", "phone", "offline_access"]) + for scope_item in self._split_by_crlf(data.get('allowed_scope')): + scopes_set.add(scope_item) + application = db.session.query(OAuth2Client).filter( + OAuth2Client.client_id==client_id + ).one() + metadata = application.client_metadata + for key, value in data: + metadata.update(key, value) + ret = db.session.query(OAuth2Client).filter( + OAuth2Client.client_id==client_id, + OAuth2Client.user_id==user_id + ).update({'client_metadata': application}) + db.session.commit() + if not ret: + LOGGER.info(f'''no application refer to this client_id {client_id}, this user id {user_id}''') + return DATABASE_UPDATE_ERROR + except sqlalchemy.exc.SQLAlchemyError as error: + LOGGER.error(error) + LOGGER.error("update one application info failed") + return DATABASE_UPDATE_ERROR + return SUCCEED + + def delete_one_application(self, user_id: str, client_id: str): + """delete one application + Returns: + Tuple[str, dict]: status_code + """ + try: + ret = db.session.query(OAuth2Client).filter( + OAuth2Client.user_id==user_id, + OAuth2Client.client_id==client_id + ).delete() + if not ret: + LOGGER.info(f'''no application refer to this client_id {client_id}, this user id {user_id}''') + return DATABASE_DELETE_ERROR + db.session.commit() + except sqlalchemy.exc.SQLAlchemyError as error: + LOGGER.error(error) + LOGGER.error(f'''delete application error, client id is {client_id}, user_id is {user_id}''') + return DATABASE_DELETE_ERROR + return SUCCEED + diff --git a/oauth2_provider/app/serialize/applications.py b/oauth2_provider/app/serialize/applications.py index e5afb55..247b1b0 100644 --- a/oauth2_provider/app/serialize/applications.py +++ b/oauth2_provider/app/serialize/applications.py @@ -12,3 +12,50 @@ # ******************************************************************************/ # marshmallow +from marshmallow import Schema, fields, validate, ValidationError +from vulcanus.restful.serialize.validate import ValidateRules + +class Oauth2ClientSchema(Schema): + """ + validators for parameter of /user/account/change + """ + client_name = fields.String(required=True,validate=validate.Length(min=1,max=10)) + client_uri = fields.String(required=True,validate=validate.URL()) + skip_auth = fields.Boolean(default=False) + register_callback_uri = fields.String(validate=validate.URL()) + allowed_scope = fields.String(validate=validate.Length(min=1, max=10)) + redirect_uri = fields.String(required=True, validate=validate.URL()) + allowed_grant_types = fields.String(required=True, validate=validate.ContainsOnly( + ["code", "token"] + )) + allowed_responses_types = fields.String(required=True, validate=validate.ContainsOnly( + ["authorization_code", "client_credentials"] + )) + token_endpoint_auth_method = fields.String(required=True, validate=validate.ContainsOnly( + ["client_secret_basic", "client_secret_post"] + )) + +# { +# "client_name": "eulercopilot", +# "client_uri": "https://qa-robot-openeuler.test.osinfra.cn", +# "skip_auth": true, +# "register_callback_uri": "", +# "allowed_scope": "", +# "redirect_uris": "https://qa-robot-openeuler.test.osinfra.cn/api/callback", +# "allowed_grant_types": "authorization_code", +# "allowed_responses_types": "code", +# "token_endpoint_auth_method": "client_secret_basic" +# } + + +class UpdateOauth2ClientSchema(Schema): + client_id = fields.String() + client_uri = fields.String() + skip_auth = fields.Boolean() + register_callback_uri = fields.String() + allowed_scope = fields.String() + redirect_uris = fields.String() + allowed_grant_types = fields.String() + allowed_responses_types = fields.String() + token_endpoint_auth_method = fields.String() + diff --git a/oauth2_provider/app/views/applications.py b/oauth2_provider/app/views/applications.py index 4f0b8b4..c9bc3dd 100644 --- a/oauth2_provider/app/views/applications.py +++ b/oauth2_provider/app/views/applications.py @@ -13,8 +13,11 @@ from flask import request from vulcanus.log.log import LOGGER from vulcanus.restful.resp import state +from oauth2_provider.app.views import validate_request from vulcanus.restful.response import BaseResponse from oauth2_provider.app.core.applications import ApplicationProxy +from vulcanus.restful.resp.state import SUCCEED, PERMESSION_ERROR +from oauth2_provider.app.serialize.applications import Oauth2ClientSchema, UpdateOauth2ClientSchema class ApplicationsView(BaseResponse): @@ -23,8 +26,13 @@ class ApplicationsView(BaseResponse): """ def get(self): - status_code, applications = ApplicationProxy().get_all_applications() - return self.response(code=status_code, data=applications) + user_id = "123456" + status_code, applications = ApplicationProxy().get_all_applications(user_id) + ret_data = { + "number": len(applications), + "applications": applications + } + return self.response(code=status_code, data=ret_data) class ApplicationsRegisteView(BaseResponse): @@ -32,9 +40,14 @@ class ApplicationsRegisteView(BaseResponse): Application registration views """ - def post(self, **params: dict): - status_code = ApplicationProxy().create_application(params) - return self.response(code=status_code) + @validate_request(schema=Oauth2ClientSchema) + def post(self, request_body, **params): + user_id = "123456" + print(type(request_body)) + print(request_body) + request_body['user_id'] = user_id + status_code, application = ApplicationProxy().create_application(request_body) + return self.response(code=status_code, data=application) class ApplicationsDetailView(BaseResponse): @@ -42,11 +55,38 @@ class ApplicationsDetailView(BaseResponse): Application detail management views """ - def get(self, app_id): - pass + def get(self, client_id): + user_id = "123456" + status_code, application = ApplicationProxy().get_one_application( + client_id=client_id, + user_id=user_id + ) + return self.response(code=status_code, data=application) + - def put(self, app_id): - pass + @validate_request(schema=UpdateOauth2ClientSchema) + def put(self, request_body, **params): + manage_user = "123456" + status_code = ApplicationProxy().update_one_application( + user_id=manage_user, + client_id=request_body.get('client_id'), + data=request_body + ) + if status_code != SUCCEED: + return self.response(code=status_code, data=dict()) + else: + status_code, application = ApplicationProxy().get_one_application( + client_id=request_body.get('client_id'), + user_id=manage_user + ) + return self.response(code=status_code, data=application) - def delete(self, app_id): - pass + + def delete(self, client_id): + user_id="123456" + status_code = ApplicationProxy().delete_one_application( + client_id=client_id, + user_id=user_id + ) + return self.response(code=status_code) + diff --git a/oauth2_provider/database/table.py b/oauth2_provider/database/table.py index 919e95f..dbef2fc 100644 --- a/oauth2_provider/database/table.py +++ b/oauth2_provider/database/table.py @@ -52,16 +52,18 @@ class OAuth2Client(db.Model, OAuth2ClientMixin): __tablename__ = 'oauth2_client' id = Column(Integer, primary_key=True) + app_name = Column(String(48), unique=True, nullable=False) user_id = Column(Integer, ForeignKey('manage_user.id', ondelete='CASCADE')) user = relationship('ManageUser') @property - def skip_auth(self): - return self.client_metadata.get('skip_auth') + def skip_authorization(self): + return self.client_metadata.get('skip_authorization') @property def register_callback_uri(self): - return self.check_metadata.get('register_callback_uri') + return self.client_metadata.get('register_callback_uri') + class OAuth2Token(db.Model, OAuth2TokenMixin): @@ -130,3 +132,14 @@ class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): nullable=False, ) client = relationship('OAuth2Client') + + +class UserAndOAuth2Client(db.Model): + __tablename__ = 'user_oauth2_client_scopes' + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('user.id', ondelete='CASCADE')) + client_id = Column(Integer, ForeignKey('oauth2_client.id', ondelete='CASCADE')) + scopes = Column(String, nullable=False) + user = relationship('User') + client = relationship('OAuth2Client') + diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index 5e28558..6238e4f 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -19,7 +19,7 @@ URLS = [ # applications (ApplicationsView, "/oauth2/applications/"), (ApplicationsRegisteView, "/oauth2/applications/register/"), - (ApplicationsDetailView, "/oauth2/applications/"), + (ApplicationsDetailView, "/oauth2/applications/"), # oauth2 (OauthorizeView, "/oauth2/authorize/"), (OauthTokenView, "/oauth2/token/"), -- Gitee