diff --git a/bashplotlib/scatterplot.py b/bashplotlib/scatterplot.py index 76d20e4..2523950 100644 --- a/bashplotlib/scatterplot.py +++ b/bashplotlib/scatterplot.py @@ -7,6 +7,7 @@ from __future__ import print_function import csv +import os import sys import optparse from .utils.helpers import * @@ -62,6 +63,7 @@ def plot_scatter(f, xs, ys, size, pch, colour, title): colour -- colour of the points title -- title of the plot """ + path_allowed_types = (str, bytes, os.PathLike) cs = None if f: if isinstance(f, str): @@ -73,13 +75,17 @@ def plot_scatter(f, xs, ys, size, pch, colour, title): ys = [float(i[1]) for i in data] if len(data[0]) > 2: cs = [i[2].strip() for i in data] - elif isinstance(xs, list) and isinstance(ys, list): - pass - else: + # try to convert any iterable data to list, so we can use any iterable object like pandas dataframe or numpy array + elif type(xs) in path_allowed_types and type(ys) in path_allowed_types: with open(xs) as fh: xs = [float(str(row).strip()) for row in fh] with open(ys) as fh: ys = [float(str(row).strip()) for row in fh] + elif isiterable(xs) and isiterable(ys): + xs = [i for i in xs] + ys = [i for i in ys] + else: + raise ValueError("Invalid data types {} or {} must be iterable, str, or pathlike".format(type(xs), type(ys))) _plot_scatter(xs, ys, size, pch, colour, title, cs)