Commit abc5fa40 authored by Marta Różańska's avatar Marta Różańska
Browse files

Merge branch 'tft_nbeats' into 'morphemic-rc1.5'

waiting for dataset logs corrected

See merge request !136
parents 3af71835 09f29e1f
......@@ -5,11 +5,15 @@ import time
from filelock import FileLock
import pytorch_lightning as pl
import os
import logging
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting.metrics import QuantileLoss, MAE, RMSE, CrossEntropy
from pytorch_forecasting import NBeats
from src.preprocess_dataset import Dataset
logging.basicConfig(
filename=f"/logs/{os.environ.get('METHOD', 'nbeats')}.out", level=logging.INFO
)
"""Script for temporal fusion transformer training"""
......@@ -36,6 +40,9 @@ def train(target_column, prediction_length, yaml_file="model.yaml"):
dataset = pd.read_csv(data_path).tail(1000)
if dataset.shape[0] < 12 * prediction_length:
logging.info(
f"dataset len: {dataset.shape[0]}, minimum points required: {12 * prediction_length}"
)
return None
ts_dataset = Dataset(dataset, target_column=target_column, **params["dataset"])
......
......@@ -5,11 +5,15 @@ import time
from filelock import FileLock
import pytorch_lightning as pl
import os
import logging
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting.metrics import QuantileLoss, MAE, RMSE, CrossEntropy
from pytorch_forecasting import TemporalFusionTransformer
from src.preprocess_dataset import Dataset
logging.basicConfig(
filename=f"/logs/{os.environ.get('METHOD', 'tft')}.out", level=logging.INFO
)
"""Script for temporal fusion transformer training"""
......@@ -37,6 +41,9 @@ def train(target_column, prediction_length, yaml_file="model.yaml"):
dataset = pd.read_csv(data_path).tail(1000)
if dataset.shape[0] < 12 * prediction_length:
logging.info(
f"dataset len: {dataset.shape[0]}, minimum points required: {12 * prediction_length}"
)
return None
ts_dataset = Dataset(dataset, target_column=target_column, **params["dataset"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment