001/*
002 * Copyright 2011 Atteo.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.atteo.moonshine.tests;
017
018import java.sql.Connection;
019import java.sql.DatabaseMetaData;
020import java.sql.ResultSet;
021import java.sql.SQLException;
022import java.sql.Statement;
023import java.util.ArrayList;
024import java.util.List;
025
026import javax.sql.DataSource;
027
028import org.atteo.moonshine.database.DatabaseService;
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032
033public class DatabaseCleaner {
034
035    private static final Logger logger = LoggerFactory.getLogger(DatabaseCleaner.class);
036
037    private final DataSource dataSource;
038
039    private final DatabaseService database;
040
041    public DatabaseCleaner(DataSource dataSource, DatabaseService database) {
042        this.dataSource = dataSource;
043        this.database = database;
044    }
045
046    /**
047     * Restore the database to its pristine state (after all migrations have run).
048     */
049    public void reset() {
050        dropTables();
051        database.executeMigrations(dataSource);
052    }
053
054    /**
055     * Clean all database tables.
056     */
057    public void clean() {
058        logger.debug("Clearing database");
059        try (Connection connection = dataSource.getConnection()) {
060
061            List<String> tables = analyseDatabase(connection);
062
063            clearTables(connection, tables);
064        } catch (SQLException e) {
065            throw new RuntimeException(e);
066        }
067    }
068
069    /**
070     * Drop all database tables.
071     */
072    public void dropTables() {
073        try (Connection connection = dataSource.getConnection()) {
074            List<String> tables = analyseDatabase(connection);
075
076            dropTables(connection, tables);
077        } catch (SQLException e) {
078            throw new RuntimeException(e);
079        }
080
081    }
082
083    private List<String> analyseDatabase(Connection connection) {
084        try {
085            List<String> tables = new ArrayList<>();
086
087            DatabaseMetaData metaData = connection.getMetaData();
088
089            try (ResultSet result = metaData.getTables(null, null, "%", new String[]{"TABLE"})) {
090                while (result.next()) {
091                    String tableName = result.getString("TABLE_NAME");
092                    tables.add(tableName);
093                }
094            }
095
096            return tables;
097        } catch (SQLException e) {
098            throw new RuntimeException("An exception occurred while trying to analyse the database.", e);
099        }
100    }
101
102    private void clearTables(Connection connection, List<String> tables) {
103        for (String table : tables) {
104            if (!table.equals("DATABASECHANGELOG") && !table.equals("DATABASECHANGELOGLOCK")) {
105                clearSingleTable(connection, table);
106            }
107        }
108
109    }
110
111
112    private void clearSingleTable(Connection connection, String tableName) {
113        try (Statement statement = connection.createStatement()) {
114            statement.executeUpdate("DELETE FROM " + tableName);
115        } catch (SQLException ex) {
116            throw new RuntimeException("Can't read table contents from table ".concat(tableName), ex);
117        }
118    }
119
120    private void dropTables(Connection connection, List<String> tables) {
121        for (String table : tables) {
122            dropTable(connection, table);
123        }
124    }
125
126    private void dropTable(Connection connection, String tableName) {
127        try (Statement statement = connection.createStatement()) {
128            statement.executeUpdate("DROP TABLE " + tableName);
129        } catch (SQLException ex) {
130            throw new RuntimeException("Can't read table contents from table ".concat(tableName), ex);
131        }
132    }
133}