JPA 表租户 SQL解析实现

1. 功能介绍

  • 针对表租户ID字段标识的多租户系统

  • 参考了Mybatis-Plus插件的TenantSqlParser进行的JPA实现,使用jsqlparser解析并修改SQL,我们不生产代码,我们只做代码的搬运工

  • 实现获取当前用户租户ID,SQL增删改查时处理租户字段,实现租户数据的隔离
    参考项目:

  • https://github.com/baomidou/mybatis-plus

  • https://github.com/JSQLParser/JSqlParser

2. 在JPA项目中引入jsqlparser依赖,本例中使用的版本号为3.1

1
2
3
4
5
            <dependency>
                <groupId>com.github.jsqlparser</groupId>
                <artifactId>jsqlparser</artifactId>
                <version>${jsqlparser.version}</version>
            </dependency>

3. 编写租户拦截器TenantInterceptor

重写hibernate提供的StatementInspector的inspect接口,参数为hibernate处理后的原始SQL,返回值为我们修改后的SQL

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import lombok.Data;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.hibernate.resource.jdbc.spi.StatementInspector;

import java.util.List;

/**
 * 参考Mybatis-Plus插件中的TenantSqlParser进行租户解析处理,其实现为使用jsqlparser对sql进行解析,拼装SQL语句
 *
 * @author wangqichang
 * @since 2019/12/5
 */
@Slf4j
@Data
@Accessors(chain = true)
public class TenantInterceptor implements StatementInspector {


    /**
     * 当前租户ID,从UserContext获取
     */
    private String tenantId;

    /**
     * 需进行租户解析的表名,需要注入
     */
    private List<String> tenantTables;

    /**
     * 需进行租户解析的租户字段名,本项目中为固定名称
     */
    private String tenantIdColumn = "tenant_id";


    /**
     * 重写StatementInspector的inspect接口,参数为hibernate处理后的原始SQL,返回值为我们修改后的SQL
     * @param sql
     * @return
     */
    @Override
    public String inspect(String sql) {
        try {
            /**
             * 非租户用户不进行解析
             */
            if (UserContext.current() == null || UserContext.current().getAdministrator()) {
                return null;
            }
            /**
             * 初始化需要进行租户解析的租户表
             */
            if (tenantTables == null) {
                TenantProperties bean = SpringContextUtil.getBean(TenantProperties.class);
                if (bean != null) {
                    tenantTables = bean.getTables();
                } else {
                    throw new RuntimeException("未能获取TenantProperties参数配置");
                }
            }

            /**
             * 从当前线程获取登录用户的所属租户ID
             */
            CurrentUser user = UserContext.current();
            tenantId = user.getTenantId();

            log.info("租户解析开始,原始SQL:{}", sql);
            Statements statements = CCJSqlParserUtil.parseStatements(sql);
            StringBuilder sqlStringBuilder = new StringBuilder();
            int i = 0;
            for (Statement statement : statements.getStatements()) {
                if (null != statement) {
                    if (i++ > 0) {
                        sqlStringBuilder.append(';');
                    }
                    sqlStringBuilder.append(this.processParser(statement));
                }
            }
            String newSql = sqlStringBuilder.toString();
            log.info("租户解析结束,解析后SQL:{}", newSql);
            return newSql;
        } catch (Exception e) {
            log.error("租户解析失败,解析SQL异常{}", e.getMessage());
            e.printStackTrace();
        } finally {
            tenantId = null;
        }
        return null;
    }

    private String processParser(Statement statement) {
        if (statement instanceof Insert) {
            this.processInsert((Insert) statement);
        } else if (statement instanceof Select) {
            this.processSelectBody(((Select) statement).getSelectBody());
        } else if (statement instanceof Update) {
            this.processUpdate((Update) statement);
        } else if (statement instanceof Delete) {
            this.processDelete((Delete) statement);
        }
        /**
         * 返回处理后的SQL
         */
        return statement.toString();
    }

    /**
     * select 语句处理
     */

    public void processSelectBody(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelectBody(withItem.getSelectBody());
            }
        } else {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
                operationList.getSelects().forEach(this::processSelectBody);
            }
        }
    }

    /**
     * insert 语句处理
     */

    public void processInsert(Insert insert) {
        if (tenantTables.contains(insert.getTable().getFullyQualifiedName())) {
            insert.getColumns().add(new Column(tenantIdColumn));
            if (insert.getSelect() != null) {
                processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
            } else if (insert.getItemsList() != null) {
                // fixed github pull/295
                ItemsList itemsList = insert.getItemsList();
                if (itemsList instanceof MultiExpressionList) {
                    ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(new StringValue(tenantId)));
                } else {
                    ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue(tenantId));
                }
            } else {
                throw new RuntimeException("Failed to process multiple-table update, please exclude the tableName or statementId");
            }
        }
    }

    /**
     * update 语句处理
     */

    public void processUpdate(Update update) {
        final Table table = update.getTable();
        if (tenantTables.contains(table.getFullyQualifiedName())) {
            update.setWhere(this.andExpression(table, update.getWhere()));
        }
    }

    /**
     * delete 语句处理
     */

    public void processDelete(Delete delete) {
        if (tenantTables.contains(delete.getTable().getFullyQualifiedName())) {
            delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
        }
    }

    /**
     * delete update 语句 where 处理
     */
    protected BinaryExpression andExpression(Table table, Expression where) {
        //获得where条件表达式
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(this.getAliasColumn(table));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (null != where) {
            if (where instanceof OrExpression) {
                return new AndExpression(equalsTo, new Parenthesis(where));
            } else {
                return new AndExpression(equalsTo, where);
            }
        }
        return equalsTo;
    }

    /**
     * 处理 PlainSelect
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        processPlainSelect(plainSelect, false);
    }

    /**
     * 处理 PlainSelect
     *
     * @param plainSelect ignore
     * @param addColumn   是否添加租户列,insert into select语句中需要
     */
    protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            if (tenantTables.contains(fromTable.getFullyQualifiedName())) {
                //#1186 github
                plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
                if (addColumn) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(
                            new Column(tenantIdColumn)));
                }
            }
        } else {
            processFromItem(fromItem);
        }
        List<Join> joins = plainSelect.getJoins();
        if (joins != null && joins.size() > 0) {
            joins.forEach(j -> {
                processJoin(j);
                processFromItem(j.getRightItem());
            });
        }
    }

    /**
     * 处理子查询等
     */
    protected void processFromItem(FromItem fromItem) {
        if (fromItem instanceof SubJoin) {
            SubJoin subJoin = (SubJoin) fromItem;
            if (subJoin.getJoinList() != null) {
                subJoin.getJoinList().forEach(this::processJoin);
            }
            if (subJoin.getLeft() != null) {
                processFromItem(subJoin.getLeft());
            }
        } else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            log.debug("Perform a subquery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
            if (lateralSubSelect.getSubSelect() != null) {
                SubSelect subSelect = lateralSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelectBody(subSelect.getSelectBody());
                }
            }
        }
    }

    /**
     * 处理联接语句
     */
    protected void processJoin(Join join) {
        if (join.getRightItem() instanceof Table) {
            Table fromTable = (Table) join.getRightItem();
            if (tenantTables.contains(fromTable.getFullyQualifiedName())) {
                join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
            }
        }
    }

    /**
     * 处理条件:
     * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
     * 默认tenantId的表达式: LongValue(1)这种依旧支持
     */
    protected Expression builderExpression(Expression currentExpression, Table table) {
        final Expression tenantExpression = new StringValue(tenantId);
        Expression appendExpression;
        if (!(tenantExpression instanceof SupportsOldOracleJoinSyntax)) {
            appendExpression = new EqualsTo();
            ((EqualsTo) appendExpression).setLeftExpression(this.getAliasColumn(table));
            ((EqualsTo) appendExpression).setRightExpression(tenantExpression);
        } else {
            appendExpression = processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
        }
        if (currentExpression == null) {
            return appendExpression;
        }
        if (currentExpression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
            doExpression(binaryExpression.getLeftExpression());
            doExpression(binaryExpression.getRightExpression());
        } else if (currentExpression instanceof InExpression) {
            InExpression inExp = (InExpression) currentExpression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelectBody(((SubSelect) rightItems).getSelectBody());
            }
        }
        if (currentExpression instanceof OrExpression) {
            return new AndExpression(new Parenthesis(currentExpression), appendExpression);
        } else {
            return new AndExpression(currentExpression, appendExpression);
        }
    }

    protected void doExpression(Expression expression) {
        if (expression instanceof FromItem) {
            processFromItem((FromItem) expression);
        } else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelectBody(((SubSelect) rightItems).getSelectBody());
            }
        }
    }

    /**
     * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
     * select a.id, b.name
     * from a
     * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
     *
     * @param expression
     * @param table
     * @return 加上别名的多租户字段表达式
     */
    protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
        //cannot add table alias for customized tenantId expression,
        // when tables including tenantId at the join table poistion
        return expression;
    }

    /**
     * 租户字段别名设置
     * <p>tableName.tenantId 或 tableAlias.tenantId</p>
     *
     * @param table 表对象
     * @return 字段
     */
    protected Column getAliasColumn(Table table) {
        StringBuilder column = new StringBuilder();
        if (null == table.getAlias()) {
            column.append(table.getName());
        } else {
            column.append(table.getAlias().getName());
        }
        column.append(".");
        column.append(tenantIdColumn);
        return new Column(column.toString());
    }

}

4.JPA拦截yml配置

1
2
3
4
5
6
7
8
9
10
spring:
  jpa:
    database: mysql
    show-sql: true
    hibernate:
      ddl-auto: update
    properties:
      hibernate:
        session_factory:
          statement_inspector: com.tba.sc.common.intercepters.TenantInterceptor

5. 租户表yml配置

1
2
3
4
# 需进行租户解析的租户表
tenant:
  tables:
    - sys_user

6. 租户表配置类

1
2
3
4
5
6
7
8
9
10
@Data
@Component
@ConfigurationProperties(prefix = "tenant")
public class TenantProperties {

    /**
     * 需要进行租户解析的租户表
     */
    private List<String> tables;
}

7. 测试类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/**
 * @author wangqichang
 * @since 2019/12/5
 */

@Slf4j
@SpringBootTest(classes = SystemApplication.class)
@RunWith(SpringRunner.class)
public class TanentTest {

    @Autowired
    UserService userService;

    @Test
    public void tenantTest() {
        CurrentUser user = new CurrentUser();
        user.setId("40285b816252ff61016253008f9f0000");
        user.setTenantId("40285b816252ff61016253008f9f0001");
        user.setAdministrator(false);
        UserContext.setCurrentUser(user);
        while (true) {
            List<UserDTO> all = userService.findAll();
            all.forEach(x -> log.info(x.toString()));
        }
    }
}

8. 测试效果如下

可以看到查询的SQL语句自动拼接了WHERE user0_.tenant_id = '40285b816252ff61016253008f9f0001'条件

1
2
2019-12-06 10:02:22.345  INFO 174116 --- [           main] c.t.s.c.intercepters.TenantInterceptor   : 租户解析开始,原始SQL:select user0_.id as id1_9_, user0_.create_date as create_d2_9_, user0_.update_date as update_d3_9_, user0_.administrator as administ4_9_, user0_.org_id as org_id10_9_, user0_.password as password5_9_, user0_.real_name as real_nam6_9_, user0_.salt as salt7_9_, user0_.tenant_id as tenant_i8_9_, user0_.user_name as user_nam9_9_ from sys_user user0_
2019-12-06 10:02:22.348  INFO 174116 --- [           main] c.t.s.c.intercepters.TenantInterceptor   : 租户解析结束,解析后SQL:SELECT user0_.id AS id1_9_, user0_.create_date AS create_d2_9_, user0_.update_date AS update_d3_9_, user0_.administrator AS administ4_9_, user0_.org_id AS org_id10_9_, user0_.password AS password5_9_, user0_.real_name AS real_nam6_9_, user0_.salt AS salt7_9_, user0_.tenant_id AS tenant_i8_9_, user0_.user_name AS user_nam9_9_ FROM sys_user user0_ WHERE user0_.tenant_id = '40285b816252ff61016253008f9f0001'

9. 说明:

  • 关于UserContext,此类为自定义的当前用户上下文,各位需要自己实现,原理为从会话中获取当前操作用户的租户ID

10.请点赞