将RestBean转为响应式 切换ORM框架使用Spring Data R2DBC 配置完成Security Configuration

This commit is contained in:
ydoily 2025-02-27 19:04:28 +08:00
parent 7c71f90431
commit 3d7011e51a
23 changed files with 649 additions and 417 deletions

49
pom.xml
View File

@ -21,6 +21,7 @@
<properties>
<java.version>21</java.version>
<jwt.version>4.4.0</jwt.version>
<r2dbc.mysql>1.4.0</r2dbc.mysql>
<lombok-version>1.18.30</lombok-version>
<netty.version>4.1.115.Final</netty.version>
<mybatis-plus.version>3.5.6</mybatis-plus.version>
@ -40,31 +41,10 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<!-- Netty 核心组件 -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns-native-macos</artifactId>
<version>${netty.version}</version> <!-- 使用最新版本 -->
<version>${netty.version}</version>
<classifier>osx-aarch_64</classifier> <!-- Mac 芯片架构Intel 使用 osx-x86_64 -->
</dependency>
@ -102,26 +82,17 @@
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- MyBatis-Plus -->
<!--r2dbc-->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-spring-boot3-starter</artifactId>
<version>${mybatis-plus.version}</version>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-r2dbc</artifactId>
</dependency>
<dependency>
<groupId>io.asyncer</groupId>
<artifactId>r2dbc-mysql</artifactId>
<version>${r2dbc.mysql}</version>
</dependency>
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-generator</artifactId>
<version>${mybatis-plus.version}</version>
</dependency>
<!-- MySQL Driver -->
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>${mysql-connector-java.version}</version>
<scope>runtime</scope>
</dependency>
<!-- Spring Boot DevTools (可选, 开发热加载) -->
<dependency>

View File

@ -1,45 +0,0 @@
package com.example.config;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import reactor.netty.http.server.HttpServer;
@Slf4j
@Configuration
public class NettyConfig {
@Bean
public NettyReactiveWebServerFactory nettyReactiveWebServerFactory() {
NettyReactiveWebServerFactory factory = new NettyReactiveWebServerFactory();
factory.addServerCustomizers(this::apply);
return factory;
}
public HttpServer apply(HttpServer httpServer) {
// 创建一个 NioEventLoopGroup指定 8 Worker 线程处理 I/O 事件
EventLoopGroup workerGroup = new NioEventLoopGroup(8);
return httpServer
// 绑定 Worker 线程池负责处理 I/O 事件
.runOn(workerGroup)
// 设置 TCP 连接队列的最大长度为 128防止服务器过载
.option(ChannelOption.SO_BACKLOG, 128)
// 启用 TCP Keep-Alive保持长连接防止连接频繁关闭
.option(ChannelOption.SO_KEEPALIVE, true)
// 监听服务器成功绑定端口的事件并记录服务器启动的地址
.doOnBound(server -> log.info("Netty Server started on: {}", server.address()))
// 监听服务器解绑端口的事件并记录服务器已停止
.doOnUnbound(server -> log.info("Netty Server stopped."))
// 监听新的连接事件并打印远程客户端的地址
.doOnConnection(con -> log.info("Connected to Netty Server: {}", con.channel().remoteAddress()))
// 监听新的 Channel 初始化事件并记录新建的 Channel 信息
.doOnChannelInit((observer, channel, remoteAddr) ->
log.info("New channel initialized: {}", channel));
}
}

View File

@ -1,20 +0,0 @@
package com.example.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.CommandLineRunner;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class NettyConfigTest implements CommandLineRunner {
private final NettyProperties nettyProperties;
public NettyConfigTest(NettyProperties nettyProperties) {
this.nettyProperties = nettyProperties;
}
@Override
public void run(String... args) {
log.info("Netty 配置加载成功: {}", nettyProperties);
}
}

View File

@ -1,56 +0,0 @@
package com.example.config;
import lombok.Data;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.List;
@Data
@ToString
@Component
@ConfigurationProperties(prefix = "netty")
public class NettyProperties {
private WebSocketConfig websocket;
private NettyOptions options;
private NettyServer server;
private LoggingConfig logging;
private ConnectionConfig connections;
@Data
@ToString
public static class WebSocketConfig {
private int maxFrameSize;
private boolean allowExtensions;
private List<String> subProtocols;
}
@Data
@ToString
public static class NettyOptions {
private int soBacklog;
private boolean soReuseaddr;
private boolean tcpNodeLay;
private boolean keepAlive;
}
@Data
@ToString
public static class NettyServer {
private int port;
}
@Data
@ToString
public static class LoggingConfig {
private String level;
private String logFile;
}
@Data
@ToString
public static class ConnectionConfig {
private int maxClients;
private int timeoutSeconds;
}
}

View File

@ -1,13 +1,20 @@
package com.example.config;
import com.example.entity.AccountDetails;
import com.example.entity.RestBean;
import com.example.entity.vo.response.AuthorizeVo;
import com.example.filter.JwtAuthenticationFilter;
import com.example.service.AccountService;
import com.example.utils.JwtUtils;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.config.web.server.SecurityWebFiltersOrder;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
@ -17,14 +24,23 @@ import org.springframework.security.web.server.context.NoOpServerSecurityContext
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
@Slf4j
@Configuration
@EnableWebFluxSecurity // 启用 WebFlux 安全配置
public class SecurityConfiguration {
@Resource
private JwtUtils utils;
@Resource
private AccountService service;
@Resource
private JwtAuthenticationFilter filter;
@Bean
public SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
http
@ -55,7 +71,8 @@ public class SecurityConfiguration {
.logout(logout -> logout
.logoutUrl("/chat/auth/logout")
.logoutSuccessHandler(this::onLogoutSuccess)
);
)
.addFilterAt(filter, SecurityWebFiltersOrder.AUTHENTICATION);
return http.build(); // 返回构建的安全过滤链
}
@ -67,21 +84,24 @@ public class SecurityConfiguration {
*/
private Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange,
Authentication authentication) {
ServerWebExchange exchange = webFilterExchange.getExchange();
ServerHttpResponse response = exchange.getResponse();
ServerHttpResponse response = webFilterExchange.getExchange().getResponse();
AccountDetails user = (AccountDetails) authentication.getPrincipal();
// 设置响应状态码为 200 (OK)
response.setStatusCode(HttpStatus.OK);
// 告诉客户端它需要解析的数据格式
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
return service.findAccountByNameOrEmail(user.getUsername())
.flatMap(account -> {
String token = utils.generateJwt(user, account.id(), account.username());
// 创建 JSON 响应对象包含成功信息
RestBean<Map<String, String>> restBean = RestBean.success(Collections.singletonMap("token", null));
// 转换 Java 对象为 JSON 字符串作为响应的实际内容
String jsonResponse = restBean.asJsonString();
AuthorizeVo vo = account.asViewObject(AuthorizeVo.class, v -> {
v.setExpireTime(utils.expireTime());
v.setToken(token);
});
return response.writeWith(Mono.just(response.bufferFactory()
.wrap(jsonResponse.getBytes(StandardCharsets.UTF_8))));
return RestBean.writeSuccessToResponse(response, vo);
})
.onErrorResume(e -> {
log.error("认证处理失败", e);
return RestBean.writeFailureToResponse(response, 401, e.getMessage());
});
}
/**
@ -93,10 +113,14 @@ public class SecurityConfiguration {
public Mono<Void> onLogoutSuccess(WebFilterExchange exchange,
Authentication authentication) {
ServerHttpResponse response = exchange.getExchange().getResponse();
// 设置响应状态码为 200 (OK)
response.setStatusCode(HttpStatus.OK);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
return null;
// 如果有JWT令牌将其加入黑名单
String authHeader = exchange.getExchange().getRequest().getHeaders().getFirst("Authorization");
if (authHeader != null) {
utils.invalidateJwt(authHeader);
}
return RestBean.writeSuccessToResponse(response);
}
/**
@ -107,17 +131,9 @@ public class SecurityConfiguration {
*/
private Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,
AuthenticationException exception) {
ServerWebExchange exchange = webFilterExchange.getExchange();
ServerHttpResponse response = exchange.getResponse();
ServerHttpResponse response = webFilterExchange.getExchange().getResponse();
// 设置响应状态码为 401 (Unauthorized)
response.setStatusCode(HttpStatus.UNAUTHORIZED);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
RestBean<String> restBean = RestBean.unauthorized(exception.getMessage());
String jsonResponse = restBean.asJsonString();
return response.writeWith(Mono.just(response.bufferFactory()
.wrap(jsonResponse.getBytes(StandardCharsets.UTF_8))));
return RestBean.writeUnauthorizedToResponse(response, exception.getMessage());
}
/**
@ -130,16 +146,7 @@ public class SecurityConfiguration {
AuthenticationException exception) {
ServerHttpResponse response = exchange.getResponse();
// 设置响应状态码为 401 (UNAUTHORIZED)
exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
RestBean<String> restBean = RestBean.unauthorized(exception.getMessage());
String errorResponse = restBean.asJsonString();
return exchange.getResponse()
.writeWith(Mono.just(exchange.getResponse()
.bufferFactory()
.wrap(errorResponse.getBytes(StandardCharsets.UTF_8))));
return RestBean.writeUnauthorizedToResponse(response, exception.getMessage());
}
/**
@ -152,15 +159,6 @@ public class SecurityConfiguration {
AccessDeniedException denied) {
ServerHttpResponse response = exchange.getResponse();
// 设置响应状态码为 403 (FORBIDDEN)
response.setStatusCode(HttpStatus.FORBIDDEN);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
RestBean<String> restBean = RestBean.forbidden(denied.getMessage());
String errorResponse = restBean.asJsonString();
return exchange.getResponse()
.writeWith(Mono.just(exchange.getResponse()
.bufferFactory()
.wrap(errorResponse.getBytes(StandardCharsets.UTF_8))));
return RestBean.writeFailureToResponse(response, 403, denied.getMessage());
}
}
}

View File

@ -0,0 +1,21 @@
package com.example.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
@Configuration
public class WebConfiguration {
@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
public static void main(String[] args) {
BCryptPasswordEncoder encoder = new BCryptPasswordEncoder();
String password = "123456";
String encodedPassword = encoder.encode(password);
System.out.println("BCrypt encoded password for '123456': " + encodedPassword);
}
}

View File

@ -0,0 +1,34 @@
package com.example.controller;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/chat/test")
public class TestController {
@GetMapping("/hello")
public String hello() {
return "hello, world!";
}
@PostMapping("/addAccount")
public String addAccount() {
return "添加一个帐户";
}
@PutMapping("/updateAccount")
public String updateAccount() {
return "所有字段都要传";
}
@PatchMapping("/updateAccountInfo")
public String updateAccountInfo() {
return "只传要更新的字段";
}
@DeleteMapping("/deleteAccountById/{id}")
public String deleteAccountById(@PathVariable int id) {
return "删除一个帐户" + id;
}
}

View File

@ -0,0 +1,52 @@
package com.example.entity;
import com.example.entity.dto.AccountDTO;
import lombok.Getter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import java.util.Collection;
import java.util.Collections;
@Getter
public class AccountDetails implements UserDetails {
private final AccountDTO account;
private AccountDetails(AccountDTO account) {
this.account = account;
}
public static AccountDetailsBuilder withAccount(AccountDTO account) {
return new AccountDetailsBuilder(account);
}
public static class AccountDetailsBuilder {
private final AccountDTO account;
private AccountDetailsBuilder(AccountDTO account) {
this.account = account;
}
public AccountDetails build() {
return new AccountDetails(account);
}
}
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return Collections.singletonList(
new SimpleGrantedAuthority("ROLE_" + account.role())
);
}
@Override
public String getPassword() {
return account.password();
}
@Override
public String getUsername() {
return account.username();
}
}

View File

@ -0,0 +1,36 @@
package com.example.entity;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.function.Consumer;
public interface BaseData {
default <V> V asViewObject(Class<V> clazz, Consumer<V> consumer) {
V v = this.asViewObject(clazz);
consumer.accept(v);
return v;
}
default <V> V asViewObject(Class<V> clazz) {
try {
Field[] declaredField = clazz.getDeclaredFields();
Constructor<V> constructor = clazz.getConstructor();
V v = constructor.newInstance();
for (Field field : declaredField) convert(field, v);
return v;
} catch (ReflectiveOperationException exception) {
throw new RuntimeException(exception.getMessage());
}
}
private void convert(Field field, Object vo) {
try {
Field source = this.getClass().getDeclaredField(field.getName());
field.setAccessible(true);
source.setAccessible(true);
field.set(vo, source.get(this));
} catch (IllegalAccessException | NoSuchFieldException ignored) {}
}
}

View File

@ -2,31 +2,105 @@ package com.example.entity;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONWriter;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import reactor.core.publisher.Mono;
import java.nio.charset.StandardCharsets;
public record RestBean<T>(int code, T data, String message) {
/**
* 创建成功响应带数据
*/
public static <T> RestBean<T> success(T data) {
return new RestBean<>(200, data, "请求成功");
return new RestBean<>(HttpStatus.OK.value(), data, "请求成功");
}
/**
* 创建成功响应无数据
*/
public static <T> RestBean<T> success() {
return success(null);
}
/**
* 创建未授权响应
*/
public static <T> RestBean<T> unauthorized(String message) {
return failure(401, message);
return failure(HttpStatus.UNAUTHORIZED.value(), message);
}
/**
* 创建禁止访问响应
*/
public static <T> RestBean<T> forbidden(String message) {
return failure(403, message);
return failure(HttpStatus.FORBIDDEN.value(), message);
}
/**
* 创建自定义错误响应
*/
public static <T> RestBean<T> failure(int code, String message) {
return new RestBean<>(code, null, message);
}
/**
* 转换为JSON字符串
*/
public String asJsonString() {
return JSONObject.toJSONString(this, JSONWriter.Feature.WriteNulls);
}
/**
* 将该对象写入响应
* 适用于WebFlux响应式环境
*/
public Mono<Void> writeToResponse(ServerHttpResponse response) {
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
response.setStatusCode(HttpStatus.valueOf(this.code));
byte[] bytes = this.asJsonString().getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = response.bufferFactory().wrap(bytes);
return response.writeWith(Mono.just(buffer));
}
/**
* 将成功响应写入响应对象带数据
* @param response 服务器HTTP响应对象
* @param data 响应数据
*/
public static <T> Mono<Void> writeSuccessToResponse(ServerHttpResponse response, T data) {
return success(data).writeToResponse(response);
}
/**
* 将成功响应写入响应对象无数据
* @param response 服务器HTTP响应对象
*/
public static Mono<Void> writeSuccessToResponse(ServerHttpResponse response) {
return success().writeToResponse(response);
}
/**
* 将失败响应写入响应对象
* @param response 服务器HTTP响应对象
* @param code HTTP状态码
* @param message 错误消息
*/
public static Mono<Void> writeFailureToResponse(ServerHttpResponse response, int code, String message) {
return failure(code, message).writeToResponse(response);
}
/**
* 将未授权响应写入响应对象
* @param response 服务器HTTP响应对象
* @param message 错误消息
*/
public static Mono<Void> writeUnauthorizedToResponse(ServerHttpResponse response, String message) {
return unauthorized(message).writeToResponse(response);
}
}

View File

@ -0,0 +1,23 @@
package com.example.entity.dto;
import com.example.entity.BaseData;
import lombok.Builder;
import org.springframework.data.annotation.Id;
import org.springframework.data.relational.core.mapping.Column;
import org.springframework.data.relational.core.mapping.Table;
import java.time.LocalDateTime;
@Table("db_account")
@Builder
public record AccountDTO(
@Id
Integer id,
@Column("username")
String username,
@Column("password")
String password,
@Column("email")
String email,
String role,
LocalDateTime registerTime
) implements BaseData {}

View File

@ -0,0 +1,13 @@
package com.example.entity.vo.response;
import lombok.Data;
import java.util.Date;
@Data
public class AuthorizeVo {
private String username;
private String role;
private String token;
private Date expireTime;
}

View File

@ -0,0 +1,60 @@
package com.example.filter;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.example.utils.JwtUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
@Slf4j
@Component
@RequiredArgsConstructor
public class JwtAuthenticationFilter implements WebFilter {
private final JwtUtils jwtUtils;
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
// 获取Authorization头
String authHeader = exchange.getRequest().getHeaders().getFirst("Authorization");
// 如果没有Authorization头或者不是需要认证的路径直接放行
if (authHeader == null || shouldSkipAuthentication(exchange)) {
return chain.filter(exchange);
}
// 解析JWT令牌
DecodedJWT jwt = jwtUtils.resolveJwt(authHeader);
if (jwt != null) {
// 从JWT中提取用户信息
UserDetails userDetails = jwtUtils.toUser(jwt);
// 创建认证对象
UsernamePasswordAuthenticationToken authentication =
new UsernamePasswordAuthenticationToken(
userDetails, null, userDetails.getAuthorities());
// 将认证信息设置到SecurityContext中
return chain.filter(exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication));
}
// 如果认证失败继续过滤器链后续会由Spring Security的认证失败处理器处理
return chain.filter(exchange);
}
private boolean shouldSkipAuthentication(ServerWebExchange exchange) {
String path = exchange.getRequest().getURI().getPath();
// 不需要认证的路径
return path.startsWith("/chat/auth/") ||
path.equals("/chat/auth/login") ||
path.equals("/chat/auth/register") ||
exchange.getRequest().getMethod().name().equals("OPTIONS");
}
}

View File

@ -1,81 +0,0 @@
package com.example.handler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class ChatServerHandler extends SimpleChannelInboundHandler<String> {
// 维护所有已连接的客户端 Channel
private static final ChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
@Override
public void channelActive(ChannelHandlerContext ctx) {
Channel incoming = ctx.channel();
log.info("Client connected: {}", incoming.remoteAddress());
// 广播通知所有在线用户
channels.writeAndFlush("[Server] - " + incoming.remoteAddress() + " joined\n");
channels.add(incoming);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx,
String msg) throws Exception {
Channel sender = ctx.channel();
log.info("Received message from {} : {}", sender.remoteAddress(), msg);
if (msg.startsWith("@")) {
String[] parts = msg.split(":", 2);
if (parts.length == 2) {
String targetAddress = parts[0].substring(1);
String privateMessage = parts[1];
// 发送私聊消息
sendPrivateMessage(sender, targetAddress, privateMessage);
return;
}
}
// 广播消息给所有客户端排除发送者
for (Channel channel : channels) {
if (channel != sender) {
channel.writeAndFlush("[" + sender.remoteAddress() + "] " + msg + "\n");
}
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
Channel outgoing = ctx.channel();
log.info("Client disconnected: {}", outgoing.remoteAddress());
// ChannelGroup 移除断开的客户端
// 广播通知所有在线用户
channels.writeAndFlush("[Server] - " + outgoing.remoteAddress() + " left\n");
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("Error occurred: ", cause);
ctx.close();
}
// 私聊
private void sendPrivateMessage(Channel sender, String targetAddress, String message) {
for (Channel channel : channels) {
if (channel.remoteAddress().toString().contains(targetAddress)) {
channel.writeAndFlush("[PRIVATE] From [" + sender.remoteAddress() + "]: " + message + "\n");
sender.writeAndFlush("[PRIVATE] To [" + targetAddress + "]: " + message + "\n");
return;
}
}
sender.writeAndFlush("[SERVER] - User " + targetAddress + " not found.\n");
}
}

View File

@ -1,36 +0,0 @@
package com.example.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class ServerHandler extends SimpleChannelInboundHandler<String> {
@Override
public void channelActive(ChannelHandlerContext ctx) {
// 当客户端连接时触发记录客户端的远程地址
log.info("Client connected: {}", ctx.channel().remoteAddress());
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, String msg) {
// 处理接收到的消息日志记录接收到的消息内容
log.info("Received message: {}", msg);
// 发送响应消息给客户端
ctx.writeAndFlush("Server received: " + msg + "\n");
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
// 当客户端断开连接时触发记录客户端的远程地址
log.info("Client disconnected: {}", ctx.channel().remoteAddress());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// 处理异常记录错误信息并关闭连接
log.error("Error occurred: ", cause);
ctx.close();
}
}

View File

@ -0,0 +1,16 @@
package com.example.repository;
import com.example.entity.dto.AccountDTO;
import org.springframework.data.r2dbc.repository.Query;
import org.springframework.data.repository.reactive.ReactiveCrudRepository;
import reactor.core.publisher.Mono;
public interface AccountRepository extends ReactiveCrudRepository<AccountDTO, Long> {
/**
* 通过用户名或邮箱查找账户
*/
@Query("SELECT * FROM db_account WHERE username = :username OR email = :email LIMIT 1")
Mono<AccountDTO> findByUsernameOrEmail(String username, String email);
}

View File

@ -1,47 +0,0 @@
package com.example.server;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import jakarta.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class NettyServer {
private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
private final EventLoopGroup workerGroup = new NioEventLoopGroup();
public void start(int port) {
try {
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childHandler(new ServerInitializer());
ChannelFuture future = bootstrap.bind(port).sync();
Channel serverChannel = future.channel();
log.info("Server started on port {}", port);
serverChannel.closeFuture().sync();
} catch (InterruptedException e) {
log.error("Netty Chat Server interrupted!", e);
Thread.currentThread().interrupt();
} finally {
shutdown();
}
}
@PreDestroy
public void shutdown() {
log.info("Shutting down Netty Chat Server...");
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}

View File

@ -1,18 +0,0 @@
package com.example.server;
import com.example.config.NettyProperties;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
@Component
@RequiredArgsConstructor
public class NettyServerStarter {
private final NettyServer nettyServer;
private final NettyProperties properties;
@PostConstruct
public void startServer() {
new Thread(() -> nettyServer.start(properties.getServer().getPort())).start();
}
}

View File

@ -1,21 +0,0 @@
package com.example.server;
import com.example.handler.ChatServerHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
public class ServerInitializer extends ChannelInitializer<SocketChannel> {
@Override
protected void initChannel(SocketChannel ch) {
ch.pipeline().addLast(
new LoggingHandler(LogLevel.INFO), // 日志
new StringDecoder(), // 解码
new StringEncoder(), // 编码
new ChatServerHandler()
);
}
}

View File

@ -0,0 +1,16 @@
package com.example.service;
import com.example.entity.dto.AccountDTO;
import reactor.core.publisher.Mono;
public interface AccountService {
/**
* 根据用户名或邮箱查找账户
*
* @param text 用户名或邮箱
* @return 返回包含账户信息的响应式 Mono
*/
Mono<AccountDTO> findAccountByNameOrEmail(String text);
}

View File

@ -0,0 +1,41 @@
package com.example.service.impl;
import com.example.entity.AccountDetails;
import com.example.entity.dto.AccountDTO;
import com.example.repository.AccountRepository;
import com.example.service.AccountService;
import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
@Service
public class AccountServiceImpl implements AccountService, ReactiveUserDetailsService {
private final AccountRepository accountRepository;
public AccountServiceImpl(AccountRepository accountRepository) {
this.accountRepository = accountRepository;
}
/**
* 实现 ReactiveUserDetailsService 接口的方法
*/
@Override
public Mono<UserDetails> findByUsername(String username) {
return findAccountByNameOrEmail(username)
.switchIfEmpty(Mono.error(new UsernameNotFoundException("此账号未注册")))
.map(accountDTO -> AccountDetails
.withAccount(accountDTO)
.build());
}
/**
* 实现 AccountService 接口的方法
*/
@Override
public Mono<AccountDTO> findAccountByNameOrEmail(String text) {
return accountRepository.findByUsernameOrEmail(text, text);
}
}

View File

@ -0,0 +1,5 @@
package com.example.utils;
public class Const {
public static final String JWT_BLACK_LIST = "jwt:blacklist:";
}

View File

@ -0,0 +1,196 @@
package com.example.utils;
import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import jakarta.annotation.Resource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import java.util.Calendar;
import java.util.Date;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
@Component
public class JwtUtils {
// JWT的秘钥
@Value("${spring.security.jwt.key}")
private String key;
// JWT的过期时间天数
@Value("${spring.security.jwt.expire}")
private int expire;
@Resource
private StringRedisTemplate template;
/**
* 生成JWT令牌
*
* @param details 用户详情
* @param id 用户ID
* @param username 用户名
* @return 生成的JWT字符串
*/
public String generateJwt(UserDetails details, int id,
String username) {
Algorithm algorithm = Algorithm.HMAC256(key);
return JWT.create()
.withJWTId(UUID.randomUUID().toString())
.withClaim("id", id)
.withClaim("username", username)
.withClaim("authorities", details.getAuthorities()
.stream().map(GrantedAuthority::getAuthority).toList()
)
.withIssuedAt(new Date())
.sign(algorithm);
}
/**
* 获取JWT的过期时间
*
* @return 过期时间
*/
public Date expireTime() {
Calendar calendar = Calendar.getInstance();
// 计算过期时间天数转换为小时
calendar.add(Calendar.HOUR, expire * 24);
return calendar.getTime();
}
/**
* 使JWT失效
*
* @param headerToken 请求头中的JWT令牌
* @return 是否成功使JWT失效
*/
public boolean invalidateJwt(String headerToken) {
String token = this.convertToken(headerToken);
if (token == null) return false;
Algorithm algorithm = Algorithm.HMAC256(key);
JWTVerifier jwtVerifier = JWT.require(algorithm).build();
try {
DecodedJWT jwt = jwtVerifier.verify(token);
String id = jwt.getId();
return this.deleteToken(id, jwt.getExpiresAt());
} catch (JWTVerificationException e) {
return false;
}
}
/**
* 将JWT的ID和过期时间存入Redis黑名单
* @param uuid JWT的ID
* @param expireTime JWT的过期时间
* @return 是否删除了Token
*/
private boolean deleteToken(String uuid, Date expireTime) {
if (this.invalidateJwt(uuid))
return false;
long expire = this.calculateExpireTime(expireTime);
this.addTokenToBlacklist(uuid, expire);
return true;
}
/**
* 计算JWT的剩余有效时间
* @param expireTime JWT的过期时间
* @return 剩余的有效时间单位毫秒
*/
private long calculateExpireTime(Date expireTime) {
Date now = new Date();
return Math.max(expireTime.getTime() - now.getTime(), 0);
}
/**
* 将JWT的ID和过期时间存入Redis黑名单
* @param uuid JWT的ID用于标识黑名单中的条目
* @param expire JWT的剩余有效时间单位毫秒
* 该方法将JWT的ID与过期时间存入Redis黑名单
* 使得黑名单中的JWT在过期后自动失效
*/
private void addTokenToBlacklist(String uuid, long expire) {
template.opsForValue().set(
Const.JWT_BLACK_LIST + uuid,
"", expire, TimeUnit.MILLISECONDS
);
}
/**
* 判断JWT是否在黑名单中
* @param uuid JWT的ID
* @return 是否在黑名单中
*/
private boolean isInvalidJwt(String uuid) {
return template.hasKey(Const.JWT_BLACK_LIST + uuid);
}
/**
* 解析JWT并验证其有效性
*
* @param headerToken 请求头中的JWT令牌
* @return 解析后的DecodedJWT对象如果无效返回null
*/
public DecodedJWT resolveJwt(String headerToken) {
String token = this.convertToken(headerToken);
if (token == null) return null;
Algorithm algorithm = Algorithm.HMAC256(key);
JWTVerifier jwtVerifier = JWT.require(algorithm).build();
try {
DecodedJWT jwt = jwtVerifier.verify(token);
if (this.isInvalidJwt(jwt.getId()))
return null;
Date expiresAt = jwt.getExpiresAt();
return new Date().after(expiresAt) ? null : jwt;
} catch (JWTVerificationException exception) {
return null;
}
}
/**
* 转换请求头中的JWT令牌
* @param headerToken 请求头中的JWT令牌
* @return JWT内容信息
*/
private String convertToken(String headerToken) {
// 校验是否为Bearer令牌
if (headerToken == null || !headerToken.startsWith("Bearer "))
return null;
return headerToken.substring(7);
}
/**
* 从JWT中提取用户信息
*
* @param jwt 解析后的JWT对象
* @return Spring Security的UserDetails对象
*/
public UserDetails toUser(DecodedJWT jwt) {
// 获取JWT的声明部分
Map<String, Claim> claims = jwt.getClaims();
return User
.withUsername(claims.get("username").asString())
.password("******")
.authorities(claims.get("authorities").asArray(String.class))
.build();
}
/**
* 从JWT中提取用户ID
*
* @param jwt 解析后的JWT对象
* @return 用户ID
*/
public Integer toId(DecodedJWT jwt) {
Map<String, Claim> claims = jwt.getClaims();
return claims.get("id").asInt();
}
}