Henry
发布于 2025-11-30 / 14 阅读
0
0

Python - FastAPI - 添加 JWT 模块

背景简介

为 FastAPI 项目添加 JWT 认证模块:使用 PyJWT 签发和校验访问令牌,使用 bcrypt 对密码进行哈希。

前置信息

详细信息

文件架构

my_project/
├── app/                          
│   ├── core/                     # 核心功能模块
│   │   ├── __init__.py           # 导出 settings, TokenData 等,定义包的公共 API
│   │   ├── config.py             # 应用配置,包含 settings 对象
│   │   ├── token.py              # Token 生成和密码哈希 (我们优化的文件)
│   │   └── auth.py               # 认证和验证逻辑 (我们优化的文件)
│   │
│   ├── models/                  
│   │   ├── __init__.py
│   │   └── user_basic_info.py    # 用户基本信息模型
│   │
│   └── main.py                   # FastAPI 应用实例的入口文件
│
├── tests/                        # 测试目录
│   ├── __init__.py               # (可选) 将 'tests' 标记为包
│   ├── conftest.py               # pytest 共享 fixtures 和配置
│   ├── test_core_token.py        # token.py 模块的单元测试
│   └── test_core_auth.py         # auth.py 模块的单元测试
│
├── .env                          # 环境变量文件 (不应提交到版本控制)
├── .gitignore                    
├── requirements.txt              # 项目生产依赖
├── requirements-dev.txt          # 项目开发和测试依赖
└── README.md                     

依赖库

# JWT 库
PyJWT
bcrypt

# 测试库
pytest
pytest-mock

代码准备

  • 修改 .env
# JWT settings
# Use a long random value, e.g.: python -c "import secrets; print(secrets.token_urlsafe(32))"
SECRET_KEY=your_very_strong_secret_key
ACCESS_TOKEN_EXPIRE_MINUTES=60
ALGORITHM="HS256"
  • 修改 app/core/config.py
# app/core/config.py
"""Application settings configuration module.

This module defines the Settings class for managing application configuration
using Pydantic BaseSettings. It loads configuration from environment variables
and from a ``.env`` file resolved relative to the current working directory
(normally the project root when the application is started from there).
"""

from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

# Case-insensitive spellings of DB_ECHO_STR that turn SQL echo on.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})


class Settings(BaseSettings):
    """Application settings configuration class.

    Holds basic app information plus API, logging, JWT, database and Redis
    configuration. Values come from environment variables and from a ``.env``
    file in the current working directory; names are matched
    case-insensitively and unknown keys are ignored. Fields without a default
    (ENV_NAME, SECRET_KEY, DATABASE_URL, REDIS_URL) must be provided, otherwise
    instantiation fails with a validation error.
    """

    # --- Basic Configuration ---
    ENV_NAME: str  #: The name of the environment (e.g., 'development', 'production').
    APP_NAME: str = "Personal Website Backend Service"  #: The name of the application.
    VERSION: str = "1.0.0"  #: The version of the application.
    DEBUG: bool = True  #: A flag to enable or disable debug mode.

    # --- API settings ---
    API_V1_STR: str = "/api/v1"  #: URL prefix for version 1 of the API.

    # --- Logging Configuration ---
    LOG_LEVEL: str = "INFO"
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # --- JWT settings ---
    SECRET_KEY: str  #: Key used to sign and verify JWTs; keep it out of version control.
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60  #: Default token lifetime (1 hour).
    ALGORITHM: str = "HS256"  #: JWT signing algorithm (HMAC with SHA-256).

    # --- Database Configuration ---
    DATABASE_URL: str  #: The database connection URL.
    DB_ECHO_STR: str = 'False'  #: Database echo setting as a string from the .env file.
    #: If True, SQLAlchemy will log all SQL statements. Derived from DB_ECHO_STR
    #: by ``derive_db_echo`` unless DB_ECHO itself is supplied. Do not write it
    #: as ``bool(DB_ECHO_STR)``: that runs once at class-definition time, and
    #: ``bool()`` of any non-empty string (including 'False') is True.
    DB_ECHO: bool = False

    # --- Redis settings ---
    REDIS_URL: str  #: Redis connection URL from the .env file.

    model_config = SettingsConfigDict(
        env_file="./.env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
    )

    @model_validator(mode="after")
    def derive_db_echo(self) -> "Settings":
        """Parse DB_ECHO_STR into DB_ECHO unless DB_ECHO was supplied explicitly."""
        if "DB_ECHO" not in self.model_fields_set:
            self.DB_ECHO = self.DB_ECHO_STR.strip().lower() in _TRUTHY_STRINGS
        return self


# Global settings instance, created (and validated) at import time.
settings = Settings()

token 模块

  • 新建 app/core/token.py
"""
Token generation utilities.

This module provides functions for creating JWT access tokens and
hashing passwords.
"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional

import bcrypt
import jwt

from app.core.config import settings

# Public API of this module: the names exported by `from app.core.token import *`.
__all__ = [
    "TokenData",
    "get_password_hash",
    "create_access_token",
]


@dataclass(init=True, repr=True, eq=True)
class TokenData:
    """Claims extracted from a decoded JWT.

    Attributes:
        sub (Optional[str]): The ``sub`` (subject) claim, typically a user ID;
            ``None`` when no subject is given.
    """

    sub: Optional[str] = None


def get_password_hash(password: str) -> str:
    """Hash a plain-text password with bcrypt and a freshly generated salt.

    Args:
        password (str): The plain-text password; it is UTF-8 encoded before
            hashing.

    Returns:
        str: The bcrypt hash string (salt included), decoded as UTF-8.
    """
    # NOTE(review): bcrypt only uses the first 72 bytes of its input; depending
    # on the installed bcrypt version longer inputs are truncated or rejected.
    password_bytes = password.encode('utf-8')
    return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')


def create_access_token(
    data: Mapping[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """Creates a signed JWT access token.

    Args:
        data (Mapping[str, Any]): Claims to include in the token
            (e.g., {'sub': user_id}). The mapping is copied, never mutated;
            any ``exp`` key it contains is overwritten.
        expires_delta (Optional[timedelta]): Token lifetime measured from the
            current UTC time. If None, ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``
            is used. A zero or negative delta produces an already-expired token.

    Returns:
        str: The JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    to_encode = dict(data)  # Mutable copy so the caller's mapping is untouched.

    # Compare with None explicitly: timedelta(0) is falsy, so a truthiness test
    # would silently replace an explicit zero lifetime with the default one.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
    )


可选: 认证模块

  • 新建 app/core/auth.py

"""
Authentication and verification utilities.

This module provides functions for verifying passwords, decoding JWTs,
and authenticating users.
"""

from typing import Optional

import bcrypt
import jwt
from jwt import PyJWTError

from app.core.config import settings
from app.core.token import TokenData
from app.models.user_basic_info import UserBasicInfo


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verifies a plain password against a bcrypt hash.

    Args:
        plain_password (str): The plain-text password to verify.
        hashed_password (str): The bcrypt hash stored in the database.

    Returns:
        bool: True if the password matches the hash, False otherwise,
        including when ``hashed_password`` is not a valid bcrypt hash.
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode('utf-8'), hashed_password.encode('utf-8')
        )
    except ValueError:
        # bcrypt raises ValueError for input it cannot process (e.g. a stored
        # value that is not a bcrypt hash: "Invalid salt"). Report a failed
        # match instead of letting it surface as a server error.
        return False


def verify_token(token: str) -> Optional[TokenData]:
    """Decode a JWT and extract its subject claim.

    The signature is checked against ``settings.SECRET_KEY`` and only
    ``settings.ALGORITHM`` is accepted; PyJWT also rejects the token if it
    carries an ``exp`` claim that has passed.

    Args:
        token (str): The encoded JWT string.

    Returns:
        Optional[TokenData]: A ``TokenData`` holding the ``sub`` claim, or
        ``None`` if decoding/verification fails or ``sub`` is absent.
    """
    try:
        claims = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except PyJWTError:
        return None

    subject: Optional[str] = claims.get("sub")
    return None if subject is None else TokenData(sub=subject)


def authenticate_user(db_user: Optional[UserBasicInfo], password: str) -> bool:
    """Authenticates a user by verifying their password.

    Args:
        db_user (Optional[UserBasicInfo]): The user object retrieved from the
            database, or None/falsy when no matching user was found.
        password (str): The plain-text password provided by the user.

    Returns:
        bool: True if authentication is successful, False otherwise.
    """
    if not db_user:
        return False

    stored_hash = db_user.password
    if not stored_hash:
        # NOTE(review): guards against users without a stored hash (e.g. a
        # nullable password column); calling .encode() on None would raise.
        return False

    if not verify_password(password, stored_hash):
        return False
    return True

测试代码

  • 新建 tests/test_core_token.py


"""
Tests for the app.core.token module.
"""

import jwt
from datetime import datetime, timedelta, timezone

import pytest
from app.core.token import TokenData, create_access_token, get_password_hash


# --- Fixtures ---

@pytest.fixture
def mock_settings(mocker):
    """Replace ``settings`` as seen by ``app.core.token`` with a configured mock.

    The patch targets the name inside the token module, where it is looked up
    at call time, rather than ``app.core.config.settings``.
    """
    return mocker.patch(
        "app.core.token.settings",
        SECRET_KEY="a-super-secret-test-key",
        ALGORITHM="HS256",
        ACCESS_TOKEN_EXPIRE_MINUTES=30,
    )


@pytest.fixture
def sample_password():
    """Plain-text password shared by the hashing tests."""
    plain = "mysecretpassword"
    return plain


# --- Tests for get_password_hash ---

def test_get_password_hash(sample_password):
    """get_password_hash returns a well-formed, per-call salted bcrypt hash."""
    hashed = get_password_hash(sample_password)
    assert isinstance(hashed, str)
    assert hashed != sample_password
    assert hashed.startswith("$2b$")
    # A bcrypt hash string is always 60 characters long
    # ("$2b$" + 2-digit cost + "$" + 22-char salt + 31-char digest).
    assert len(hashed) == 60
    # A fresh salt on every call means the same password hashes differently.
    assert get_password_hash(sample_password) != hashed


# --- Tests for create_access_token ---

def test_create_access_token_with_default_expiry(mock_settings):
    """Tests token creation with the default expiration time.

    The expected ``exp`` is bracketed by timestamps taken just before and just
    after the call, so the check does not rely on an arbitrary tolerance.
    """
    data = {"sub": "test-user-id"}
    lifetime = timedelta(minutes=mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    before = datetime.now(timezone.utc)
    token = create_access_token(data)
    after = datetime.now(timezone.utc)

    assert isinstance(token, str)

    payload = jwt.decode(
        token, mock_settings.SECRET_KEY, algorithms=[mock_settings.ALGORITHM]
    )
    assert payload["sub"] == "test-user-id"
    assert "exp" in payload

    # PyJWT serialises datetime claims as whole seconds (fraction dropped),
    # hence int() on the lower bound.
    assert int((before + lifetime).timestamp()) <= payload["exp"]
    assert payload["exp"] <= (after + lifetime).timestamp()


def test_create_access_token_with_custom_expiry(mock_settings):
    """Tests token creation with a custom expiration time.

    The expected ``exp`` is bracketed by timestamps taken just before and just
    after the call, so the check does not rely on an arbitrary tolerance.
    """
    data = {"sub": "another-user-id"}
    custom_delta = timedelta(seconds=60)

    before = datetime.now(timezone.utc)
    token = create_access_token(data, expires_delta=custom_delta)
    after = datetime.now(timezone.utc)

    payload = jwt.decode(
        token, mock_settings.SECRET_KEY, algorithms=[mock_settings.ALGORITHM]
    )
    assert payload["sub"] == "another-user-id"

    # PyJWT serialises datetime claims as whole seconds (fraction dropped),
    # hence int() on the lower bound.
    assert int((before + custom_delta).timestamp()) <= payload["exp"]
    assert payload["exp"] <= (after + custom_delta).timestamp()


# --- Tests for TokenData ---

def test_token_data_creation():
    """TokenData keeps an explicit subject and defaults it to None."""
    explicit = TokenData(sub="user-123")
    default = TokenData()

    assert explicit.sub == "user-123"
    assert default.sub is None

  • 新建 tests/test_core_auth.py

"""
Tests for the app.core.auth module.
"""

import jwt
import bcrypt
from datetime import datetime, timedelta

import pytest
from jwt import PyJWTError

from app.core.auth import authenticate_user, verify_password, verify_token
from app.core.token import TokenData


# --- Fixtures ---

@pytest.fixture
def mock_settings(mocker):
    """Replace ``settings`` as seen by ``app.core.auth`` with a configured mock.

    The patch targets the name inside the auth module, where it is looked up
    at call time, rather than ``app.core.config.settings``.
    """
    return mocker.patch(
        "app.core.auth.settings",
        SECRET_KEY="a-super-secret-test-key",
        ALGORITHM="HS256",
    )

@pytest.fixture
def mock_user(mocker):
    """Stand-in for a UserBasicInfo row whose password is "correctpassword"."""
    stored_hash = bcrypt.hashpw("correctpassword".encode('utf-8'), bcrypt.gensalt())
    fake_user = mocker.MagicMock()
    fake_user.password = stored_hash.decode('utf-8')
    return fake_user


@pytest.fixture
def valid_token(mock_settings):
    """A JWT for "test-user-id" signed with the mocked key (no exp claim)."""
    claims = {"sub": "test-user-id"}
    return jwt.encode(
        claims,
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM,
    )

@pytest.fixture
def expired_token(mock_settings):
    """A JWT for "test-user-id" whose exp claim lies 10 minutes in the past."""
    from datetime import timezone

    # Aware UTC datetime: datetime.utcnow() returns a naive value and is
    # deprecated since Python 3.12.
    expired_at = datetime.now(timezone.utc) - timedelta(minutes=10)
    return jwt.encode(
        {"sub": "test-user-id", "exp": expired_at},
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM
    )


# --- Tests for verify_password ---

# Hash of "correctpassword", computed once at collection time for both cases.
_CORRECT_HASH = bcrypt.hashpw(b"correctpassword", bcrypt.gensalt()).decode('utf-8')


@pytest.mark.parametrize(
    ("plain_password", "hashed_password", "expected"),
    [
        ("correctpassword", _CORRECT_HASH, True),
        ("wrongpassword", _CORRECT_HASH, False),
    ],
)
def test_verify_password(plain_password, hashed_password, expected):
    """verify_password accepts the matching password and rejects others."""
    outcome = verify_password(plain_password, hashed_password)
    assert outcome is expected


# --- Tests for verify_token ---

def test_verify_token_success(valid_token):
    """A correctly signed token yields TokenData carrying its subject."""
    decoded = verify_token(valid_token)
    assert isinstance(decoded, TokenData)
    assert decoded.sub == "test-user-id"

def test_verify_token_invalid_signature(mock_settings):
    """A token signed with a different key is rejected."""
    claims = {"sub": "test-user-id"}
    forged_token = jwt.encode(claims, "wrong-secret", algorithm=mock_settings.ALGORITHM)
    assert verify_token(forged_token) is None

def test_verify_token_expired(expired_token):
    """A token whose exp claim has passed is rejected."""
    decoded = verify_token(expired_token)
    assert decoded is None

def test_verify_token_missing_sub(mock_settings):
    """A validly signed token without a 'sub' claim is rejected."""
    claims_without_sub = {"user": "test-user-id"}
    token_no_sub = jwt.encode(
        claims_without_sub,
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM,
    )
    assert verify_token(token_no_sub) is None

def test_verify_token_malformed(mock_settings):
    """Tests verification failure with a completely malformed token.

    Uses ``mock_settings`` like the other verify_token tests, so the result
    does not depend on the real SECRET_KEY/ALGORITHM from the environment.
    """
    assert verify_token("this.is.not.a.valid.jwt") is None


# --- Tests for authenticate_user ---

def test_authenticate_user_success(mock_user):
    """The correct password authenticates the mocked user."""
    outcome = authenticate_user(mock_user, "correctpassword")
    assert outcome is True

def test_authenticate_user_wrong_password(mock_user):
    """An incorrect password does not authenticate the mocked user."""
    outcome = authenticate_user(mock_user, "wrongpassword")
    assert outcome is False

def test_authenticate_user_user_not_found():
    """A missing user (None) never authenticates, whatever the password."""
    missing_user = None
    assert authenticate_user(missing_user, "anypassword") is False

验证

  • 执行 pytest
======================================================================================= test session starts =======================================================================================
platform linux -- Python 3.11.13, pytest-8.3.3, pluggy-1.6.0 -- /home/myserver/pers/pers-website/env/python3d11d13/bin/python3.11
cachedir: .pytest_cache
rootdir: /home/myserver/pers/pers-website/backend
configfile: pytest.ini
testpaths: tests
plugins: anyio-4.11.0, mock-3.14.0, asyncio-0.24.0
asyncio: mode=Mode.AUTO, default_loop_scope="function"
collected 21 items                                                                                                                                                                                

...(此处省略部分测试输出)...
tests/test_core_auth.py::test_verify_token_missing_sub PASSED                                                                                                                               [ 61%]
tests/test_core_auth.py::test_verify_token_malformed PASSED                                                                                                                                 [ 66%]
tests/test_core_auth.py::test_authenticate_user_success PASSED                                                                                                                              [ 71%]
tests/test_core_auth.py::test_authenticate_user_wrong_password PASSED                                                                                                                       [ 76%]
tests/test_core_auth.py::test_authenticate_user_user_not_found PASSED                                                                                                                       [ 80%]
tests/test_core_token.py::test_get_password_hash PASSED                                                                                                                                     [ 85%]
tests/test_core_token.py::test_create_access_token_with_default_expiry PASSED                                                                                                               [ 90%]
tests/test_core_token.py::test_create_access_token_with_custom_expiry PASSED                                                                                                                [ 95%]
tests/test_core_token.py::test_token_data_creation PASSED                                                                                                                                   [100%]

以上便是本文的全部内容,感谢您的阅读,如遇到任何问题,欢迎在评论区留言讨论。



评论