aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils/persist.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils/persist.py')
-rw-r--r--bot/utils/persist.py46
1 files changed, 37 insertions, 9 deletions
diff --git a/bot/utils/persist.py b/bot/utils/persist.py
index 35e1e41a..939a95c9 100644
--- a/bot/utils/persist.py
+++ b/bot/utils/persist.py
@@ -2,10 +2,12 @@ import sqlite3
from pathlib import Path
from shutil import copyfile
+from bot.seasons.season import get_seasons
+
DIRECTORY = Path("data") # directory that has a persistent volume mapped to it
-def datafile(file_path: Path) -> Path:
+def make_persistent(file_path: Path) -> Path:
"""
Copy datafile at the provided file_path to the persistent data directory.
@@ -19,22 +21,48 @@ def datafile(file_path: Path) -> Path:
If the persistent file already exists, it won't be overwritten with the
clean default file, just returning the Path instead to the existing file.
+ Note: Avoid using the same file name as other features in the same seasons
+ as otherwise only one datafile can be persistent and will be returned for
+ both cases.
+
Example Usage:
- >>> clean_default_datafile = Path("bot", "resources", "datafile.json")
- >>> persistent_file_path = datafile(clean_default_datafile)
+ >>> import json
+ >>> template_datafile = Path("bot", "resources", "evergreen", "myfile.json")
+ >>> path_to_persistent_file = make_persistent(template_datafile)
+ >>> print(path_to_persistent_file)
+ data/evergreen/myfile.json
+ >>> with path_to_persistent_file.open("w+") as f:
+ >>> data = json.load(f)
"""
+ # ensure the persistent data directory exists
+ if not DIRECTORY.exists():
+ DIRECTORY.mkdir()
+
if not file_path.is_file():
raise OSError(f"File not found at {file_path}.")
- persistant_path = Path(DIRECTORY, file_path.name)
+ # detect season in datafile path for assigning to subdirectory
+ season = next((s for s in get_seasons() if s in file_path.parts), None)
+
+ if season:
+ # make sure subdirectory exists first
+ subdirectory = Path(DIRECTORY, season)
+ if not subdirectory.exists():
+ subdirectory.mkdir()
+
+ persistent_path = Path(subdirectory, file_path.name)
+
+ else:
+ persistent_path = Path(DIRECTORY, file_path.name)
- if not persistant_path.exists():
- copyfile(file_path, persistant_path)
+ # copy base/template datafile to persistent directory
+ if not persistent_path.exists():
+ copyfile(file_path, persistent_path)
- return persistant_path
+ return persistent_path
def sqlite(db_path: Path) -> sqlite3.Connection:
"""Copy sqlite file to the persistent data directory and return an open connection."""
- persistant_path = datafile(db_path)
- return sqlite3.connect(persistant_path)
+ persistent_path = make_persistent(db_path)
+ return sqlite3.connect(persistent_path)