summaryrefslogtreecommitdiff
path: root/database/gwf_converter
diff options
context:
space:
mode:
Diffstat (limited to 'database/gwf_converter')
-rw-r--r--database/gwf_converter/gwf_converter.py44
1 files changed, 30 insertions, 14 deletions
diff --git a/database/gwf_converter/gwf_converter.py b/database/gwf_converter/gwf_converter.py
index 81de2440..902bd93f 100644
--- a/database/gwf_converter/gwf_converter.py
+++ b/database/gwf_converter/gwf_converter.py
@@ -17,7 +17,8 @@ class Task:
self.job = job
self.submit_time = submit_time
self.run_time = run_time
- self.flops = 10 ** 9 * run_time * num_processors
+ self.cores = num_processors
+ self.flops = 4000 * run_time * num_processors
self.dependency_gwf_ids = dependency_gwf_ids
self.db_id = -1
self.dependencies = []
@@ -55,8 +56,7 @@ def get_jobs_from_gwf_file(file_name):
return jobs.values()
-def write_to_db(trace_name, jobs):
- conn = mariadb.connect(user='opendc', password='opendcpassword', database='opendc')
+def write_to_db(conn, trace_name, jobs):
cursor = conn.cursor()
trace_id = execute_insert_query(conn, cursor, "INSERT INTO traces (name) VALUES ('%s')" % trace_name)
@@ -66,9 +66,10 @@ def write_to_db(trace_name, jobs):
% ("Job %d" % job.gwf_id, trace_id))
for task in job.tasks:
- task.db_id = execute_insert_query(conn, cursor, "INSERT INTO tasks (start_tick, total_flop_count, job_id, "
- "parallelizability) VALUES (%d,%d,%d,'PARALLEL')"
- % (task.submit_time, task.flops, job.db_id))
+ task.db_id = execute_insert_query(conn, cursor,
+ "INSERT INTO tasks (start_tick, total_flop_count, core_count, job_id) "
+ "VALUES (%d,%d,%d,%d)"
+ % (task.submit_time, task.flops, task.cores, job.db_id))
for job in jobs:
for task in job.tasks:
@@ -77,9 +78,6 @@ def write_to_db(trace_name, jobs):
"VALUES (%d,%d)"
% (dependency.db_id, task.db_id))
- conn.close()
-
-
def execute_insert_query(conn, cursor, sql):
try:
cursor.execute(sql)
@@ -90,10 +88,28 @@ def execute_insert_query(conn, cursor, sql):
return cursor.lastrowid
+def main(trace_path):
+ trace_name = sys.argv[2] if (len(sys.argv) > 2) else \
+ os.path.splitext(os.path.basename(trace_path))[0]
+ gwf_jobs = get_jobs_from_gwf_file(trace_path)
+
+ host = os.environ.get('PERSISTENCE_HOST','localhost')
+ user = os.environ.get('PERSISTENCE_USER','opendc')
+ password = os.environ.get('PERSISTENCE_PASSWORD','opendcpassword')
+ database = os.environ.get('PERSISTENCE_DATABASE','opendc')
+ conn = mariadb.connect(host=host, user=user, password=password, database=database)
+ write_to_db(conn, trace_name, gwf_jobs)
+ conn.close()
+
+
if __name__ == "__main__":
if len(sys.argv) < 2:
- sys.exit("Usage: %s trace-name" % sys.argv[0])
-
- gwf_trace_name = sys.argv[1]
- gwf_jobs = get_jobs_from_gwf_file(os.path.join("traces", gwf_trace_name + ".gwf"))
- write_to_db(gwf_trace_name, gwf_jobs)
+ sys.exit("Usage: %s file [name]" % sys.argv[0])
+
+ if sys.argv[1] in ("-a", "--all"):
+ for f in os.listdir("traces"):
+ if f.endswith(".gwf"):
+ print("Converting {}".format(f))
+ main(os.path.join("traces", f))
+ else:
+ main(sys.argv[1])