背景简介
为 FastAPI 项目添加 JWT 模块。
前置信息
- Python 3.11.13 【Conda - 创建 Python 环境】
- Redis 7 【Redis - 基于 Docker 部署】
详细信息
文件架构
my_project/
├── app/
│ ├── core/ # 核心功能模块
│ │ ├── __init__.py # 导出 settings, TokenData 等,定义包的公共 API
│ │ ├── config.py # 应用配置,包含 settings 对象
│ │ ├── token.py # Token 生成和密码哈希 (我们优化的文件)
│ │ └── auth.py # 认证和验证逻辑 (我们优化的文件)
│ │
│ ├── models/
│ │ ├── __init__.py
│ │ └── user_basic_info.py # 用户基本信息模型
│ │
│ └── main.py # FastAPI 应用实例的入口文件
│
├── tests/ # 测试目录
│ ├── __init__.py # (可选) 将 'tests' 标记为包
│ ├── conftest.py # pytest 共享 fixtures 和配置
│ ├── test_core_token.py # token.py 模块的单元测试
│ └── test_core_auth.py # auth.py 模块的单元测试
│
├── .env # 环境变量文件 (不应提交到版本控制)
├── .gitignore
├── requirements.txt # 项目生产依赖
├── requirements-dev.txt # 项目开发和测试依赖
└── README.md
依赖库
# JWT 库
PyJWT
bcrypt
# 测试库
pytest
pytest-mock
代码准备
- 修改
.env
# JWT settings
# 生产环境请替换为高强度随机密钥,例如:python -c "import secrets; print(secrets.token_urlsafe(32))"
SECRET_KEY=your_very_strong_secret_key
ACCESS_TOKEN_EXPIRE_MINUTES=60
ALGORITHM="HS256"
- 修改
app/core/config.py
# app/core/config.py
"""Application settings configuration module.

This module defines the Settings class, which manages application
configuration via pydantic-settings' ``BaseSettings``. Values come from
environment variables and from a ``.env`` file. The file path is ``./.env``,
i.e. it is resolved against the current working directory of the process,
so the application (and pytest) should be started from the project root.
"""
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

# Case-insensitive spellings of DB_ECHO_STR that turn SQL echo on; these
# mirror the "true" strings pydantic itself accepts for bool fields.
_TRUTHY_STRINGS = frozenset({"1", "t", "true", "y", "yes", "on"})


class Settings(BaseSettings):
    """Application settings.

    Holds basic app information, API/logging options, JWT parameters and
    database/Redis connection settings. Fields without a default (ENV_NAME,
    SECRET_KEY, DATABASE_URL, REDIS_URL) are required and must be supplied by
    the environment or the ``./.env`` file, otherwise instantiation fails.
    """

    # --- Basic Configuration ---
    ENV_NAME: str  #: The name of the environment (e.g., 'development', 'production').
    APP_NAME: str = "Personal Website Backend Service"  #: The name of the application.
    VERSION: str = "1.0.0"  #: The version of the application.
    DEBUG: bool = True  #: A flag to enable or disable debug mode.

    # --- API settings ---
    API_V1_STR: str = "/api/v1"  # API version prefix

    # --- Logging Configuration ---
    LOG_LEVEL: str = "INFO"
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # --- JWT settings ---
    SECRET_KEY: str  # Key used to sign JWTs; must be kept secret.
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60  # Token expires after 1 hour
    ALGORITHM: str = "HS256"  # JWT signing algorithm

    # --- Database Configuration ---
    DATABASE_URL: str  #: The database connection URL.
    DB_ECHO_STR: str = 'False'  #: Database echo setting as a string from the .env file.
    # If True, SQLAlchemy will log all SQL statements. When DB_ECHO is not set
    # explicitly, it is derived from DB_ECHO_STR by derive_db_echo() below.
    # (The previous default `bool(DB_ECHO_STR)` was bool('False') == True.)
    DB_ECHO: bool = False

    # --- Redis settings ---
    REDIS_URL: str  # Redis connection URL from .env file

    model_config = SettingsConfigDict(
        env_file="./.env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
    )

    @model_validator(mode="after")
    def derive_db_echo(self) -> "Settings":
        """Set DB_ECHO from DB_ECHO_STR unless DB_ECHO was provided directly.

        Returns:
            Settings: The same instance, with DB_ECHO resolved.
        """
        if "DB_ECHO" not in self.model_fields_set:
            self.DB_ECHO = self.DB_ECHO_STR.strip().lower() in _TRUTHY_STRINGS
        return self


# Global settings instance
settings = Settings()
token 模块
- 新建
app/core/token.py
"""
Token generation utilities.
This module provides functions for creating JWT access tokens and
hashing passwords.
"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional
import bcrypt
import jwt
from app.core.config import settings
# Defines the public API of this module
# Public API: the only names exported by `from app.core.token import *`.
__all__ = ["TokenData", "get_password_hash", "create_access_token"]
@dataclass
class TokenData:
    """Claims extracted from a decoded JWT.

    Attributes:
        sub (Optional[str]): The token subject (typically the user ID);
            None when no subject was supplied.
    """

    sub: Optional[str] = None  # value of the JWT "sub" claim
def get_password_hash(password: str) -> str:
    """Hash a plain-text password with bcrypt and a freshly generated salt.

    NOTE(review): bcrypt only considers the first 72 bytes of its input
    (recent bcrypt releases reject longer input with ValueError); consider
    enforcing a maximum password length before calling this.

    Args:
        password (str): The plain-text password; it is UTF-8 encoded first.

    Returns:
        str: The bcrypt hash (salt included), decoded back to a UTF-8 string.
    """
    digest = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    return digest.decode('utf-8')
def create_access_token(
    data: Mapping[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """Create a signed JWT access token.

    Args:
        data (Mapping[str, Any]): Claims to embed in the token
            (e.g. {'sub': user_id}). The mapping is copied, never mutated.
        expires_delta (Optional[timedelta]): Lifetime of the token. If None,
            settings.ACCESS_TOKEN_EXPIRE_MINUTES minutes is used. Any explicit
            timedelta, including timedelta(0), is honored as given.

    Returns:
        str: The JWT, signed with settings.SECRET_KEY using settings.ALGORITHM.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so the previous
    # `if expires_delta:` silently replaced a zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Copy into a plain dict; any "exp" already present in `data` is overwritten.
    to_encode = dict(data)
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
可选: 认证模块
- 创建用户信息表,作为认证参考 【Python - 基于 SQLAlchemy + Alembic + PostgreSQL 建表】
- 创建
app/core/auth.py
"""
Authentication and verification utilities.
This module provides functions for verifying passwords, decoding JWTs,
and authenticating users.
"""
from typing import Optional
import bcrypt
import jwt
from jwt import PyJWTError
from app.core.config import settings
from app.core.token import TokenData
from app.models.user_basic_info import UserBasicInfo
def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
    """Check a plain-text password against a stored bcrypt hash.

    Args:
        plain_password (str): The plain-text password to verify.
        hashed_password (Optional[str]): The bcrypt hash stored in the
            database. None or an empty string never matches.

    Returns:
        bool: True if the password matches the hash. False otherwise,
        including when the stored value is missing or is not a valid
        bcrypt hash.
    """
    if not hashed_password:
        # No stored hash (e.g. an empty/NULL password column): nothing can
        # match, and calling .encode() on None would raise AttributeError.
        return False
    try:
        return bcrypt.checkpw(
            plain_password.encode('utf-8'), hashed_password.encode('utf-8')
        )
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") when the stored value
        # is not a well-formed bcrypt hash. Report a mismatch instead of
        # turning a corrupt row into an unhandled server error.
        return False
def verify_token(token: str) -> Optional[TokenData]:
    """Decode a JWT and extract its subject.

    The signature is checked with settings.SECRET_KEY, and only
    settings.ALGORITHM is accepted. Any PyJWTError (bad signature, expired
    "exp", malformed token, ...) results in None.

    NOTE(review): no `require` option is passed to jwt.decode, so a token
    without an "exp" claim is accepted and never expires. Consider
    options={"require": ["exp"]} once all issued tokens carry "exp".

    Args:
        token (str): The encoded JWT.

    Returns:
        Optional[TokenData]: TokenData with the "sub" claim, or None if the
        token is invalid or has no "sub" claim.
    """
    try:
        claims = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except PyJWTError:
        return None
    subject: Optional[str] = claims.get("sub")
    return None if subject is None else TokenData(sub=subject)
def authenticate_user(db_user: UserBasicInfo, password: str) -> bool:
    """Decide whether `password` is correct for the given database user.

    Args:
        db_user (UserBasicInfo): The user row, or a falsy value (e.g. None)
            when the lookup found nothing.
        password (str): The plain-text password supplied by the client.

    Returns:
        bool: True only when a user is present and the password matches
        the user's stored hash.
    """
    if db_user:
        return bool(verify_password(password, db_user.password))
    # No user: authentication can never succeed.
    return False
测试代码
- 新建
tests/test_core_token.py
"""
Tests for the app.core.token module.
"""
import jwt
from datetime import datetime, timedelta, timezone
import pytest
from app.core.token import TokenData, create_access_token, get_password_hash
# --- Fixtures ---
@pytest.fixture
def mock_settings(mocker):
    """Replace the `settings` object seen by app.core.token with test values.

    The patch targets app.core.token.settings, not app.core.config.settings,
    because token.py binds the object into its own namespace with
    `from app.core.config import settings`.
    """
    patched = mocker.patch("app.core.token.settings")
    patched.SECRET_KEY = "a-super-secret-test-key"
    patched.ALGORITHM = "HS256"
    patched.ACCESS_TOKEN_EXPIRE_MINUTES = 30
    return patched
@pytest.fixture
def sample_password():
    """Return a fixed plain-text password for the hashing tests."""
    password = "mysecretpassword"
    return password
# --- Tests for get_password_hash ---
def test_get_password_hash(sample_password):
    """get_password_hash returns a bcrypt ($2b$) string distinct from its input."""
    result = get_password_hash(sample_password)
    assert isinstance(result, str)
    assert result.startswith("$2b$")
    assert result != sample_password
# --- Tests for create_access_token ---
def test_create_access_token_with_default_expiry(mock_settings):
    """Without expires_delta, "exp" is now + ACCESS_TOKEN_EXPIRE_MINUTES."""
    token = create_access_token({"sub": "test-user-id"})
    assert isinstance(token, str)

    claims = jwt.decode(
        token, mock_settings.SECRET_KEY, algorithms=[mock_settings.ALGORITHM]
    )
    assert claims["sub"] == "test-user-id"
    assert "exp" in claims

    lifetime = timedelta(minutes=mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    expected = (datetime.now(timezone.utc) + lifetime).timestamp()
    # Allow a little slack for the time elapsed between creation and this check.
    assert abs(claims["exp"] - expected) < 2
def test_create_access_token_with_custom_expiry(mock_settings):
    """An explicit expires_delta overrides the configured default lifetime."""
    lifetime = timedelta(seconds=60)
    token = create_access_token({"sub": "another-user-id"}, expires_delta=lifetime)

    claims = jwt.decode(
        token, mock_settings.SECRET_KEY, algorithms=[mock_settings.ALGORITHM]
    )
    assert claims["sub"] == "another-user-id"

    expected = (datetime.now(timezone.utc) + lifetime).timestamp()
    assert abs(claims["exp"] - expected) < 2
# --- Tests for TokenData ---
def test_token_data_creation():
    """TokenData stores an explicit subject and defaults it to None."""
    assert TokenData(sub="user-123").sub == "user-123"
    assert TokenData().sub is None
- 新建
tests/test_core_auth.py
"""
Tests for the app.core.auth module.
"""
import jwt
import bcrypt
from datetime import datetime, timedelta
import pytest
from jwt import PyJWTError
from app.core.auth import authenticate_user, verify_password, verify_token
from app.core.token import TokenData
# --- Fixtures ---
@pytest.fixture
def mock_settings(mocker):
    """Replace the `settings` object seen by app.core.auth with test values.

    The patch targets app.core.auth.settings, not app.core.config.settings,
    because auth.py binds the object into its own namespace with
    `from app.core.config import settings`.
    """
    fake = mocker.patch("app.core.auth.settings")
    fake.SECRET_KEY = "a-super-secret-test-key"
    fake.ALGORITHM = "HS256"
    return fake
@pytest.fixture
def mock_user(mocker):
    """Stand-in for a UserBasicInfo row whose password is bcrypt('correctpassword')."""
    fake_user = mocker.MagicMock()
    salt = bcrypt.gensalt()
    fake_user.password = bcrypt.hashpw(b"correctpassword", salt).decode('utf-8')
    return fake_user
@pytest.fixture
def valid_token(mock_settings):
    """Return a JWT for subject 'test-user-id', signed with the mocked secret.

    No "exp" claim is included, so this token carries no expiry.
    """
    claims = {"sub": "test-user-id"}
    return jwt.encode(
        claims,
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM,
    )
@pytest.fixture
def expired_token(mock_settings):
    """Return a JWT whose "exp" claim lies 10 minutes in the past.

    Uses a timezone-aware UTC timestamp, like app.core.token does. By
    contrast, datetime.utcnow() returns a naive datetime and is deprecated
    since Python 3.12.
    """
    # Local import: this module's top-level datetime import lacks `timezone`.
    from datetime import timezone

    expired_at = datetime.now(timezone.utc) - timedelta(minutes=10)
    return jwt.encode(
        {"sub": "test-user-id", "exp": expired_at},
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM,
    )
# --- Tests for verify_password ---
# Hash "correctpassword" once at collection time; both cases share it.
_CORRECT_PASSWORD_HASH = bcrypt.hashpw(
    "correctpassword".encode('utf-8'), bcrypt.gensalt()
).decode('utf-8')


@pytest.mark.parametrize(
    "plain_password, hashed_password, expected",
    [
        ("correctpassword", _CORRECT_PASSWORD_HASH, True),
        ("wrongpassword", _CORRECT_PASSWORD_HASH, False),
    ],
    # Explicit ids: otherwise pytest derives the test IDs from the randomly
    # salted hash strings, so they change on every run (breaking -k / --lf).
    ids=["matching-password", "wrong-password"],
)
def test_verify_password(plain_password, hashed_password, expected):
    """verify_password returns True only for the password that was hashed."""
    assert verify_password(plain_password, hashed_password) is expected
# --- Tests for verify_token ---
def test_verify_token_success(valid_token):
    """A correctly signed token yields TokenData carrying its subject."""
    result = verify_token(valid_token)
    assert isinstance(result, TokenData)
    assert result.sub == "test-user-id"
def test_verify_token_invalid_signature(mock_settings):
    """A token signed with a different secret is rejected (None)."""
    forged = jwt.encode(
        {"sub": "test-user-id"},
        "wrong-secret",
        algorithm=mock_settings.ALGORITHM,
    )
    result = verify_token(forged)
    assert result is None
def test_verify_token_expired(expired_token):
    """A token whose "exp" is in the past is rejected (None)."""
    result = verify_token(expired_token)
    assert result is None
def test_verify_token_missing_sub(mock_settings):
    """A validly signed token without a 'sub' claim is rejected (None)."""
    claims_without_sub = {"user": "test-user-id"}
    token = jwt.encode(
        claims_without_sub,
        mock_settings.SECRET_KEY,
        algorithm=mock_settings.ALGORITHM,
    )
    assert verify_token(token) is None
def test_verify_token_malformed(mock_settings):
    """A string that is not a decodable JWT yields None.

    Requests mock_settings like the other verify_token tests, so the result
    does not depend on the real application's SECRET_KEY/ALGORITHM.
    """
    assert verify_token("this.is.not.a.valid.jwt") is None
# --- Tests for authenticate_user ---
def test_authenticate_user_success(mock_user):
    """The correct password authenticates an existing user."""
    outcome = authenticate_user(mock_user, "correctpassword")
    assert outcome is True
def test_authenticate_user_wrong_password(mock_user):
    """A wrong password does not authenticate an existing user."""
    outcome = authenticate_user(mock_user, "wrongpassword")
    assert outcome is False
def test_authenticate_user_user_not_found():
    """With no user (None), authentication fails regardless of the password."""
    outcome = authenticate_user(None, "anypassword")
    assert outcome is False
验证
- 执行 pytest
======================================================================================= test session starts =======================================================================================
platform linux -- Python 3.11.13, pytest-8.3.3, pluggy-1.6.0 -- /home/myserver/pers/pers-website/env/python3d11d13/bin/python3.11
cachedir: .pytest_cache
rootdir: /home/myserver/pers/pers-website/backend
configfile: pytest.ini
testpaths: tests
plugins: anyio-4.11.0, mock-3.14.0, asyncio-0.24.0
asyncio: mode=Mode.AUTO, default_loop_scope="function"
collected 21 items
……(此处省略前 12 条测试输出)……
tests/test_core_auth.py::test_verify_token_missing_sub PASSED [ 61%]
tests/test_core_auth.py::test_verify_token_malformed PASSED [ 66%]
tests/test_core_auth.py::test_authenticate_user_success PASSED [ 71%]
tests/test_core_auth.py::test_authenticate_user_wrong_password PASSED [ 76%]
tests/test_core_auth.py::test_authenticate_user_user_not_found PASSED [ 80%]
tests/test_core_token.py::test_get_password_hash PASSED [ 85%]
tests/test_core_token.py::test_create_access_token_with_default_expiry PASSED [ 90%]
tests/test_core_token.py::test_create_access_token_with_custom_expiry PASSED [ 95%]
tests/test_core_token.py::test_token_data_creation PASSED [100%]
以上便是本文的全部内容,感谢您的阅读,如遇到任何问题,欢迎在评论区留言讨论。