Commit 27174686 authored by Anna Warno's avatar Anna Warno
Browse files

lock added, arima horizon corrcted

parent 0669d546
import os
from multiprocessing import Pool
import stomp
import json
from amq_message_python_library import * # python amq-message-python-library
import logging
import time
from datetime import datetime
from pytz import timezone
from datetime import datetime
import time
AMQ_USER = os.environ.get("AMQ_USER", "admin")
AMQ_PASSWORD = os.environ.get("AMQ_PASSWORD", "admin")
......@@ -25,10 +24,7 @@ def run_process(args):
def start_training(metrics_to_predict):
processes = (("retrain.py", metrics_to_predict), ("predict.py", metrics_to_predict))
pool = Pool(processes=2)
pool.map(run_process, processes)
run_process(("predict.py", metrics_to_predict))
class StartListener(stomp.ConnectionListener):
......@@ -112,11 +108,11 @@ def main():
)
# msg1 = Msg()
# msg1.body = '[{"metric": "memory", "level": 3, "publish_rate": 30000}]'
# msg1.body = '[{"metric": "value", "level": 3, "publish_rate": 30000}]'
# msg2 = Msg()
# msg2.body = (
# "{"
# + f'"metrics": ["memory"],"timestamp": {int(time.time())}, "epoch_start": {int(time.time()) + 30}, "number_of_forward_predictions": 8,"prediction_horizon": 30'
# + f'"metrics": ["value"],"timestamp": {int(time.time())}, "epoch_start": {int(time.time()) + 30}, "number_of_forward_predictions": 8,"prediction_horizon": 30'
# + "}"
# )
......
......@@ -120,30 +120,24 @@ def main():
dataset_preprocessor.prepare_csv()
global time_0
for metric in predicted_metrics:
for i in range(number_of_forward_predictions[metric]):
prediction_msgs = predict(
metric,
prediction_length,
prediction_hor=prediction_horizon,
timestamp=time_0,
)
if i == (number_of_forward_predictions[metric] - 1):
print(
f"time_0 difference seconds {start_time + (i + 1) * prediction_horizon // 1000 - int(time.time())}"
)
prediction_msgs = predict(
metric,
prediction_length,
prediction_hor=prediction_horizon,
timestamp=time_0,
)
if prediction_msgs:
dest = f"{PRED_TOPIC_PREF}.{metric}"
if prediction_msgs:
dest = f"{PRED_TOPIC_PREF}.{metric}"
for message in prediction_msgs:
logging.info(
f"Sending predictions for {metric} metric TIME: {datetime.now(pytz.timezone(TZ)).strftime('%d/%m/%Y %H:%M:%S')}"
)
start_conn.send_to_topic(dest, message[metric])
influxdb_conn.send_to_influxdb(metric, message)
for message in prediction_msgs:
logging.info(
f"Sending predictions for {metric} metric TIME: {datetime.now(pytz.timezone(TZ)).strftime('%d/%m/%Y %H:%M:%S')}"
)
start_conn.send_to_topic(dest, message[metric])
influxdb_conn.send_to_influxdb(metric, message)
end_time = int(time.time())
print(f"sleeping {prediction_cycle - (end_time - start_time)} seconds")
time_0 = time_0 + prediction_cycle
time_to_wait = prediction_cycle - (end_time - start_time)
......
......@@ -30,6 +30,19 @@ def predict(
os.environ.get("DATA_PATH", "./"), f'{os.environ.get("APP_NAME", "demo")}.csv'
)
if not os.path.isfile(data_path):
logging.info("Dataset not found")
return None
dataset = pd.read_csv(data_path)
new_ts_dataset = Dataset(dataset, target_column=target_column, **params["dataset"])
if new_ts_dataset.dropped_recent_series: # series with recent data was too short
logging.info(
f"Not enough fresh data, unable to predict TIME: {datetime.now(pytz.timezone(TZ)).strftime('%d/%m/%Y %H:%M:%S')}"
)
print("Not enough fresh data, unable to predict TIME:")
return None
dataset = pd.read_csv(data_path)
ts_dataset = Dataset(dataset, target_column=target_column, **params["dataset"])
......@@ -72,5 +85,4 @@ def predict(
}
}
msgs.append(msg)
return msgs
......@@ -120,7 +120,8 @@ def train(target_column, prediction_length, yaml_file="model.yaml"):
with lock:
torch.save(model.state_dict(), model_path)
else:
torch.save(model.state_dict(), model_path)
with lock:
torch.save(model.state_dict(), model_path)
msg = {
"metrics": [target_column],
......
......@@ -121,7 +121,8 @@ def train(target_column, prediction_length, yaml_file="model.yaml"):
with lock:
torch.save(tft.state_dict(), model_path)
else:
torch.save(tft.state_dict(), model_path)
with lock:
torch.save(tft.state_dict(), model_path)
msg = {
"metrics": [target_column],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment