개발/머신러닝, 딥러닝
[딥러닝] 텐서플로우 데이터셋(tensorflow_datasets) 활용
토마토Tomato
2021. 2. 4. 19:53
import tensorflow_datasets as tfds
train_dataset = tfds.load('iris', split='train[:80%]')
valid_dataset = tfds.load('iris', split='train[80%:]')
- 'iris': 불러오고자 하는 데이터 셋 이름
- split='train': 'train' set을 가져온다
- 만약 해당 데이터셋에 'test'와 'validation'이 있을 경우 불어올 수 있다.
- 'train[:80%]'는 80%까지 가져오겠다는 말
- 파이썬 문법처럼 음수도 사용가능. (e.g. 'train[-20%:]')
데이터 확인
train_dataset.take(5)
처음 데이터 5개를 리스트(?) 형식으로 반환해준다.
X, y 분리 및 one hot encoding으로 변환 (preprocessing)
dataset 정보는 텐서플로우 데이터셋 공식문서에서 확인(label class 종류, dictionary label 값 등 확인할 수 있음)
def proprocess(data):
x = data['features']
y = data['label']
y = tf.one_hot(y, 3)
return x, y
tf.one_hot: label을 one-hot encoding으로 변환해줌. 3은 label의 개수(공식문서 num_classes에서 확인하면 됨)