# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for loading files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import pandas as pd


def filename(env_name, noops, dev_measure, dev_fun, baseline, beta,
             value_discount, seed, path='', suffix=''):
  """Build the CSV file path that encodes the given set of parameters.

  The base name has the form
  `<env_name>_<noops|nonoops>_<dev_measure>_<dev_fun>_<baseline>_beta_<beta>`
  `_vd_<value_discount><suffix><_seed>.csv`.

  Args:
    env_name: value used as the first component of the file name.
    noops: if truthy the name contains 'noops', otherwise 'nonoops'.
    dev_measure: value used as the third component of the file name.
    dev_fun: value used as the fourth component of the file name.
    baseline: value used as the fifth component of the file name.
    beta: value written after the literal 'beta_'.
    value_discount: value written after the literal 'vd_'.
    seed: if truthy, '_<seed>' is appended before the extension. Any falsy
      seed (e.g. None, 0 or '') adds nothing to the name.
    path: directory joined in front of the file name (default: no directory).
    suffix: text inserted verbatim right before the seed part.

  Returns:
    The file name joined onto `path` with `os.path.join`.
  """
  noops_tag = 'noops' if noops else 'nonoops'
  if seed:
    seed_tag = '_' + str(seed)
  else:
    seed_tag = ''
  # Positional fields, in order: env name, noops tag, deviation measure,
  # deviation function, baseline, beta, value discount, suffix, seed tag.
  base_name = '{0}_{1}_{2}_{3}_{4}_beta_{5}_vd_{6}{7}{8}.csv'.format(
      env_name, noops_tag, dev_measure, dev_fun, baseline, beta,
      value_discount, suffix, seed_tag)
  return os.path.join(path, base_name)


def load_files(baseline, dev_measure, dev_fun, value_discount, beta, env_name,
               noops, path, suffix, seed_list, final=True):
  """Load result files generated by run_experiment with the given parameters.

  One CSV path per seed is built with `filename`. Paths that are not existing
  files are skipped silently, so the result only contains data for the seeds
  whose files were found.

  Args:
    baseline: forwarded to `filename`.
    dev_measure: forwarded to `filename`.
    dev_fun: forwarded to `filename`.
    value_discount: forwarded to `filename`.
    beta: forwarded to `filename`.
    env_name: forwarded to `filename`.
    noops: forwarded to `filename`.
    path: directory in which the result files are looked up.
    suffix: forwarded to `filename`.
    seed_list: iterable of seeds; each one is converted with `int` before it
      is used to build the file name.
    final: if True, keep only the rows of each file whose 'episode' value
      equals that file's maximum episode; if False, keep all rows.

  Returns:
    A DataFrame with the rows of all files that were found, concatenated in
    the order of `seed_list`. An empty DataFrame is returned when `seed_list`
    is empty or none of the files exist.

  Raises:
    KeyError: if `final` is True and a loaded file has no 'episode' column.
  """
  frames = []
  for seed in seed_list:
    csv_path = filename(baseline=baseline, dev_measure=dev_measure,
                        dev_fun=dev_fun, value_discount=value_discount,
                        beta=beta, env_name=env_name, noops=noops, path=path,
                        suffix=suffix, seed=int(seed))
    if not os.path.isfile(csv_path):
      # Missing results are tolerated. Skipping them (instead of adding an
      # empty placeholder frame) keeps empty frames out of pd.concat.
      continue
    results = pd.read_csv(csv_path, index_col=0)
    if final:
      episodes = results['episode']
      # Series.max() yields NaN for a file without rows, so the comparison
      # selects nothing instead of raising like the builtin max() would.
      results = results[episodes == episodes.max()]
    frames.append(results)
  if not frames:
    # pd.concat raises ValueError on an empty list.
    return pd.DataFrame()
  return pd.concat(frames)