From ad19a6225148dc13e85ab2c53cce6e7261e8ee2f Mon Sep 17 00:00:00 2001
From: Megan Henning <meganhenning@flywheel.io>
Date: Mon, 27 Feb 2017 16:29:50 -0600
Subject: [PATCH] Bug fixes

---
 api/auth/authproviders.py | 28 ++++++++++++++--------------
 api/web/base.py           |  3 +++
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/api/auth/authproviders.py b/api/auth/authproviders.py
index 95ea6e14..431bb2ee 100644
--- a/api/auth/authproviders.py
+++ b/api/auth/authproviders.py
@@ -1,6 +1,6 @@
 import requests
 import json
-from . import APIAuthProviderException, APIUnknownUserException
+from . import APIAuthProviderException
 from .. import config, util
 
 
@@ -26,7 +26,7 @@ class AuthProvider(object):
         when auth_type is dynamic.
         """
         if auth_type in AuthProviders:
-            provider_class = AuthProviders[auth_type].value
+            provider_class = AuthProviders[auth_type]
             return provider_class()
         else:
             raise NotImplementedError('Auth type {} is not supported'.format(auth_type))
@@ -35,13 +35,13 @@ class AuthProvider(object):
 class JWTAuthProvider(AuthProvider):
 
     def __init__(self):
-        super(JWTAuthProvider,self).__init__(AuthProviders.ldap.key)
+        super(JWTAuthProvider,self).__init__('ldap')
 
-    def validate_code(code):
+    def validate_code(self, code):
         uid = self.validate_user_exists(code)
         return code, None, uid
 
-    def validate_user_exists(token):
+    def validate_user_exists(self, token):
         r = requests.post(self.config['id_endpoint'], data={'token': token})
         if not r.ok:
             raise APIAuthProviderException('User token not valid')
@@ -54,16 +54,16 @@ class JWTAuthProvider(AuthProvider):
 class GoogleOAuthProvider(AuthProvider):
 
     def __init__(self):
-        super(GoogleAuthProvider,self).__init__(AuthProviders.google.key)
+        super(GoogleOAuthProvider,self).__init__('google')
 
-    def validate_code(code):
+    def validate_code(self, code):
         payload = {
             'client_id':        self.config['client_id'],
             'client_secret':    self.config['client_secret'],
             'code':             code,
             'grant_type':       'authorization_code'
         }
-        r = requests.post(self.config['token_url'], data=payload)
+        r = requests.post(self.config['token_endpoint'], data=payload)
         if not r.ok:
             raise APIAuthProviderException('User code not valid')
 
@@ -74,7 +74,7 @@ class GoogleOAuthProvider(AuthProvider):
 
         return token, refresh_token, uid
 
-    def validate_user_exists(token):
+    def validate_user_exists(self, token):
         r = requests.get(self.config['id_endpoint'], headers={'Authorization': 'Bearer ' + token})
         if not r.ok:
             raise APIAuthProviderException('User token not valid')
@@ -87,16 +87,16 @@ class GoogleOAuthProvider(AuthProvider):
 class WechatOAuthProvider(AuthProvider):
 
     def __init__(self):
-        super(WechatAuthProvider,self).__init__(AuthProviders.wechat.key)
+        super(WechatOAuthProvider,self).__init__('wechat')
 
-    def validate_code(code):
+    def validate_code(self, code):
         payload = {
             'appid':        self.config['client_id'],
             'secret':       self.config['client_secret'],
             'code':         code,
             'grant_type':   'authorization_code'
         }
-        r = requests.post(self.config['token_url'], params=payload)
+        r = requests.post(self.config['token_endpoint'], params=payload)
         if not r.ok:
             raise APIAuthProviderException('User code not valid')
 
@@ -107,8 +107,8 @@ class WechatOAuthProvider(AuthProvider):
 
         return token, refresh_token, uid
 
-AuthProviders = util.Enum('AuthProviders', {
+AuthProviders = {
     'google'    : GoogleOAuthProvider,
     'ldap'      : JWTAuthProvider,
     'wechat'    : WechatOAuthProvider
-})
+}
diff --git a/api/web/base.py b/api/web/base.py
index eb4e03c0..d9299a84 100644
--- a/api/web/base.py
+++ b/api/web/base.py
@@ -16,6 +16,7 @@ from .. import config
 from ..types import Origin
 from .. import validators
 from ..auth.authproviders import AuthProvider
+from ..auth import APIAuthProviderException
 from ..dao import APIConsistencyException, APIConflictException, APINotFoundException, APIPermissionException, APIValidationException, dbutil
 from ..dao.hierarchy import get_parent_tree
 from ..web.request import log_access, AccessType
@@ -341,6 +342,8 @@ class RequestHandler(webapp2.RequestHandler):
         elif isinstance(exception, validators.InputValidationException):
             code = 400
             self.request.logger.warning(str(exception))
+        elif isinstance(exception, APIAuthProviderException):
+            code = 401
         elif isinstance(exception, APIConsistencyException):
             code = 400
         elif isinstance(exception, APIPermissionException):
-- 
GitLab