001/*
002 * Copyright 2011 Atteo.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.atteo.moonshine.tests;
017
018import java.sql.Connection;
019import java.sql.DatabaseMetaData;
020import java.sql.ResultSet;
021import java.sql.SQLException;
022import java.sql.Statement;
023import java.util.ArrayList;
024import java.util.List;
025
026import javax.sql.DataSource;
027
028import org.slf4j.Logger;
029import org.slf4j.LoggerFactory;
030
031public class DatabaseCleaner {
032    private static final Logger logger = LoggerFactory.getLogger(DatabaseCleaner.class);
033
034    public static void clean(DataSource dataSource) {
035        logger.debug("Clearing database");
036        try (Connection connection = dataSource.getConnection()) {
037
038            List<String> tables = analyseDatabase(connection);
039
040            clearTables(connection, tables);
041        } catch (SQLException e) {
042            throw new RuntimeException(e);
043        }
044    }
045
046    private static List<String> analyseDatabase(Connection connection) {
047        try {
048            List<String> tables = new ArrayList<>();
049
050            DatabaseMetaData metaData = connection.getMetaData();
051
052            try (ResultSet result = metaData.getTables(null, null, "%", new String[]{"TABLE"})) {
053                while (result.next()) {
054                    String tableName = result.getString("TABLE_NAME");
055                    if (!tableName.equals("DATABASECHANGELOG") && !tableName.equals("DATABASECHANGELOGLOCK")) {
056                        tables.add(tableName);
057                    }
058                }
059            }
060
061            return tables;
062        } catch (SQLException e) {
063            throw new RuntimeException("An exception occurred while trying to analyse the database.", e);
064        }
065    }
066
067    private static void clearTables(Connection connection, List<String> tables) {
068        for (String table : tables) {
069            clearSingleTable(connection, table);
070        }
071    }
072
073    private static void clearSingleTable(Connection connection, String tableName) {
074        try (Statement statement = connection.createStatement()) {
075            statement.executeUpdate("DELETE FROM " + tableName);
076        } catch (SQLException ex) {
077            throw new RuntimeException("Can't read table contents from table ".concat(tableName), ex);
078        }
079    }
080}